From c42524c68c4a5a8035a4abe524fd0e37c66320c6 Mon Sep 17 00:00:00 2001 From: Timo Tjaden Stomberg <timo.stomberg@uni-bonn.de> Date: Tue, 30 Aug 2022 12:37:01 +0200 Subject: [PATCH] changes from martin to allow the use of cpu --- projects/asos/config.py | 1 + projects/asos/utils.py | 2 +- projects/main_config.py | 14 ++++++++++++-- tlib/ttorch/model.py | 5 +++-- tlib/ttorch/train.py | 5 +++-- 5 files changed, 20 insertions(+), 7 deletions(-) diff --git a/projects/asos/config.py b/projects/asos/config.py index 8218d21..5dcf7a9 100644 --- a/projects/asos/config.py +++ b/projects/asos/config.py @@ -13,6 +13,7 @@ dataset = 'anthroprotect' # anthroprotect, places # parameters batch_size = 32 max_image_size = 2048 +device = main_config.device num_workers = main_config.num_workers random_seed = 0 diff --git a/projects/asos/utils.py b/projects/asos/utils.py index aea982d..d810140 100644 --- a/projects/asos/utils.py +++ b/projects/asos/utils.py @@ -64,7 +64,7 @@ def load_trainer(): :return: loaded trainer object """ - return ttorch.train.ClassTrainer().load(log_dir=config.log_path) + return ttorch.train.ClassTrainer().load(log_dir=config.log_path, device=config.device) def load_asos(): diff --git a/projects/main_config.py b/projects/main_config.py index ee9ee9b..9fb7e00 100644 --- a/projects/main_config.py +++ b/projects/main_config.py @@ -1,5 +1,7 @@ import os +import torch + # ------------------------------------------------------------------------------------------------------------ # Please check the following configurations: @@ -17,9 +19,11 @@ data_folder = os.path.expanduser('~/data/anthroprotect') # Define the number of workers to load data while training the model and running model predictions. num_workers = 8 -# You might want to run the following lines on your system to avoid abortion while training the model. 
+# device: 'cuda', 'cuda:<cuda_id>' or 'cpu' +device = 'cuda' + +# You might want to run the following line on your system to keep model training from being aborted. # Read more at: https://pytorch.org/docs/stable/multiprocessing.html#file-descriptor-file-descriptor -import torch.multiprocessing torch.multiprocessing.set_sharing_strategy('file_system') # ------------------------------------------------------------------------------------------------------------ @@ -40,3 +44,9 @@ places365_file_infos_path_raw = os.path.join(working_dir, 'file_infos.csv') # If you want to change the name of the 'logs' folder in which logging files etc. are stored, you need to change this path: log_path = os.path.join(working_dir, 'logs') + + +# Fail fast if a cuda device is configured but cuda is not available: + +if device.startswith('cuda') and not torch.cuda.is_available(): + raise Exception("WARNING: cuda is selected as device but not available on your machine.") diff --git a/tlib/ttorch/model.py b/tlib/ttorch/model.py index 2e41155..63b7202 100644 --- a/tlib/ttorch/model.py +++ b/tlib/ttorch/model.py @@ -94,16 +94,17 @@ class Module(nn.Module): state_changer.reverse(model=self.model) # return to train mode if model was in that mode before training - def load(self, path): + def load(self, path, device='cuda'): """ Load model from given path using model_state_dict within dictionary into this object. 
Infos for loading models with pytorch: https://pytorch.org/tutorials/beginner/saving_loading_models.html :param path: path to model checkpoint, which must be a dictionary with entry 'model_state_dict' + :param device: device the checkpoint is mapped onto when loading ('cuda', 'cuda:<cuda_id>' or 'cpu'; default: 'cuda') :return: self (model); but also overwrites own __dict__ """ - checkpoint = torch.load(path) + checkpoint = torch.load(path, map_location=torch.device(device)) model_dict = checkpoint['model_state_dict'] self.__dict__.clear() diff --git a/tlib/ttorch/train.py b/tlib/ttorch/train.py index e822577..e95f8aa 100644 --- a/tlib/ttorch/train.py +++ b/tlib/ttorch/train.py @@ -921,7 +921,7 @@ class Trainer: state_changer.reverse(model=self.model) # return to train mode if model was in that mode before training - def load(self, log_dir, dummy_model=None, dummy_optimizer=None): + def load(self, log_dir, dummy_model=None, dummy_optimizer=None, device='cuda'): """ Load trainer from given path. Infos how to save and load pytorch modules: https://pytorch.org/tutorials/beginner/saving_loading_models.html @@ -929,13 +929,14 @@ class Trainer: :param log_dir: logging directory :param dummy_model: model object to load state_dict of model with :param dummy_optimizer: optimizer object to load state_dict of optimizer with + :param device: device the checkpoint is mapped onto when loading ('cuda', 'cuda:<cuda_id>' or 'cpu'; default: 'cuda') :return: self """ checkpoint_path = ttorch.utils.get_last_checkpoint(log_dir=log_dir) # load checkpoint - checkpoint = torch.load(checkpoint_path) + checkpoint = torch.load(checkpoint_path, map_location=torch.device(device)) # overwrite self self.__dict__.clear() -- GitLab