From 2bd88f6cbdedc6f1e6f18acb88ef786715be2916 Mon Sep 17 00:00:00 2001
From: Timo Tjaden Stomberg <timo.stomberg@uni-bonn.de>
Date: Wed, 26 Oct 2022 16:02:03 +0200
Subject: [PATCH] bug: does now work on cpu

---
 tlib/ttorch/train.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tlib/ttorch/train.py b/tlib/ttorch/train.py
index e89170a..5906e94 100644
--- a/tlib/ttorch/train.py
+++ b/tlib/ttorch/train.py
@@ -1114,7 +1114,7 @@ def load_trainer(log_dir, Class=None, datamodule_folder=None, device='cuda'):
     """
 
     checkpoint_path = ttorch.utils.get_last_checkpoint(log_dir)
-    checkpoint = torch.load(checkpoint_path)
+    checkpoint = torch.load(checkpoint_path, map_location=torch.device(device))
 
     # Class
     if Class is not None:
-- 
GitLab