From c42524c68c4a5a8035a4abe524fd0e37c66320c6 Mon Sep 17 00:00:00 2001
From: Timo Tjaden Stomberg <timo.stomberg@uni-bonn.de>
Date: Tue, 30 Aug 2022 12:37:01 +0200
Subject: [PATCH] changes from martin to allow the use of cpu

---
 projects/asos/config.py |  1 +
 projects/asos/utils.py  |  2 +-
 projects/main_config.py | 14 ++++++++++++--
 tlib/ttorch/model.py    |  5 +++--
 tlib/ttorch/train.py    |  5 +++--
 5 files changed, 20 insertions(+), 7 deletions(-)

diff --git a/projects/asos/config.py b/projects/asos/config.py
index 8218d21..5dcf7a9 100644
--- a/projects/asos/config.py
+++ b/projects/asos/config.py
@@ -13,6 +13,7 @@ dataset = 'anthroprotect'  # anthroprotect, places
 # parameters
 batch_size = 32
 max_image_size = 2048
+device = main_config.device
 num_workers = main_config.num_workers
 
 random_seed = 0
diff --git a/projects/asos/utils.py b/projects/asos/utils.py
index aea982d..d810140 100644
--- a/projects/asos/utils.py
+++ b/projects/asos/utils.py
@@ -64,7 +64,7 @@ def load_trainer():
     :return: loaded trainer object
     """
 
-    return ttorch.train.ClassTrainer().load(log_dir=config.log_path)
+    return ttorch.train.ClassTrainer().load(log_dir=config.log_path, device=config.device)
 
 
 def load_asos():
diff --git a/projects/main_config.py b/projects/main_config.py
index ee9ee9b..9fb7e00 100644
--- a/projects/main_config.py
+++ b/projects/main_config.py
@@ -1,5 +1,7 @@
 import os
 
+import torch
+
 
 # ------------------------------------------------------------------------------------------------------------
 # Please check the following configurations:
@@ -17,9 +19,11 @@ data_folder = os.path.expanduser('~/data/anthroprotect')
 # Define the number of workers to load data while training the model and running model predictions.
 num_workers = 8
 
-# You might want to run the following lines on your system to avoid abortion while training the model.
+# device: 'cuda', 'cuda:<cuda_id>' or 'cpu'
+device = 'cuda'
+
+# You might want to run the following line on your system to prevent the process from aborting while training the model.
 # Read more at: https://pytorch.org/docs/stable/multiprocessing.html#file-descriptor-file-descriptor
-import torch.multiprocessing
 torch.multiprocessing.set_sharing_strategy('file_system')
 # ------------------------------------------------------------------------------------------------------------
 
@@ -40,3 +44,9 @@ places365_file_infos_path_raw = os.path.join(working_dir, 'file_infos.csv')
 
 # If you want to change the name of the 'logs' folder in which logging files etc. are stored, you need to change this path:
 log_path = os.path.join(working_dir, 'logs')
+
+
+# Fail early if the configured device is not available:
+
+if device.startswith('cuda') and not torch.cuda.is_available():
+    raise Exception("WARNING: cuda is selected as device but not available on your machine.")
diff --git a/tlib/ttorch/model.py b/tlib/ttorch/model.py
index 2e41155..63b7202 100644
--- a/tlib/ttorch/model.py
+++ b/tlib/ttorch/model.py
@@ -94,16 +94,17 @@ class Module(nn.Module):
 
         state_changer.reverse(model=self.model)  # return to train mode if model was in that mode before training
 
-    def load(self, path):
+    def load(self, path, device='cuda'):
         """
         Load model from given path using model_state_dict within dictionary into this object.
         Infos for loading models with pytorch: https://pytorch.org/tutorials/beginner/saving_loading_models.html
 
         :param path: path to model checkpoint, which must be a dictionary with entry 'model_state_dict'
+        :param device: device onto which the checkpoint is mapped when loading ('cuda', 'cuda:<cuda_id>' or 'cpu'; default: 'cuda')
         :return: self (model); but also overwrites own __dict__
         """
 
-        checkpoint = torch.load(path)
+        checkpoint = torch.load(path, map_location=torch.device(device))
         model_dict = checkpoint['model_state_dict']
         
         self.__dict__.clear()
diff --git a/tlib/ttorch/train.py b/tlib/ttorch/train.py
index e822577..e95f8aa 100644
--- a/tlib/ttorch/train.py
+++ b/tlib/ttorch/train.py
@@ -921,7 +921,7 @@ class Trainer:
 
         state_changer.reverse(model=self.model)  # return to train mode if model was in that mode before training
 
-    def load(self, log_dir, dummy_model=None, dummy_optimizer=None):
+    def load(self, log_dir, dummy_model=None, dummy_optimizer=None, device='cuda'):
         """
         Load trainer from given path.
         Infos how to save and load pytorch modules: https://pytorch.org/tutorials/beginner/saving_loading_models.html
@@ -929,13 +929,14 @@ class Trainer:
         :param log_dir: logging directory
         :param dummy_model: model object to load state_dict of model with
         :param dummy_optimizer: optimizer object to load state_dict of optimizer with
+        :param device: device onto which the checkpoint is mapped when loading ('cuda', 'cuda:<cuda_id>' or 'cpu'; default: 'cuda')
         :return: self
         """
 
         checkpoint_path = ttorch.utils.get_last_checkpoint(log_dir=log_dir)
 
         # load checkpoint
-        checkpoint = torch.load(checkpoint_path)
+        checkpoint = torch.load(checkpoint_path, map_location=torch.device(device))
 
         # overwrite self
         self.__dict__.clear()
-- 
GitLab