From aa6d9dd6128dd71fced347abd74eb473b6f74a5a Mon Sep 17 00:00:00 2001
From: Michael <m.langguth@fz-juelich.de>
Date: Wed, 3 Feb 2021 15:42:37 +0100
Subject: [PATCH] Bugfix in time tracking (for training on multiple GPUs).

---
 .../main_scripts/main_train_models.py         | 31 +++++++++----------
 1 file changed, 15 insertions(+), 16 deletions(-)

diff --git a/video_prediction_tools/main_scripts/main_train_models.py b/video_prediction_tools/main_scripts/main_train_models.py
index f61c2f16..0b1abfa4 100644
--- a/video_prediction_tools/main_scripts/main_train_models.py
+++ b/video_prediction_tools/main_scripts/main_train_models.py
@@ -289,22 +289,21 @@ class TrainModel(object):
         with open(cnode_file, "w") as fjs:
             json.dump({"worker{0}".format(str(hvd.local_rank())): host}, fjs)
 
-    @staticmethod
-    def save_timing_to_pkl(total_time, training_time, time_per_iteration, output_dir):
+    def save_timing_to_pkl(self, training_time, time_per_iteration):
         """
         Saves tracked time per iteration step, training time and total time to pickle-file
-        :param total_time: tracked total time
         :param training_time: tracked training time
         :param time_per_iteration: tracked iteration step time (list)
-        :param output_dir: path to directory where the pickle-files will be stored
         :return: -
         """
-        with open(os.path.join(output_dir, "timing_total_time.pkl"), "wb") as f:
-            pkl.dump(total_time, f)
-        with open(os.path.join(output_dir, "timing_training_time.pkl"), "wb") as f:
+        with open(os.path.join(self.output_dir, "timing_total_time.pkl"), "wb") as f:
+            pkl.dump(time.time() - self.start_time, f)
+        with open(os.path.join(self.output_dir, "timing_training_time.pkl"), "wb") as f:
             pkl.dump(training_time, f)
-        with open(os.path.join(output_dir, "timing_per_iteration_time.pkl"), "wb") as f:
+        with open(os.path.join(self.output_dir, "timing_per_iteration_time.pkl"), "wb") as f:
             pkl.dump(time_per_iteration, f)
+        with open(os.path.join(self.output_dir, "total_steps.pkl"), "wb") as f:
+            pkl.dump(self.total_steps, f)
 
     def train_model(self):
         """
@@ -354,15 +353,15 @@ class TrainModel(object):
                         TrainModel.save_results_to_pkl(train_losses,val_losses,self.output_dir)
                         TrainModel.plot_train(train_losses,val_losses,step,self.output_dir)
 
-            # track time (save to pickle-files)
-            train_time = time.time() - run_start_time   #Total train time over all the iterations
-            total_run_time = time.time() - self.start_time
+            if hvd.rank() == 0:
+                # track time (save to pickle-files)
+                train_time = time.time() - run_start_time   #Total train time over all the iterations
 
-            TrainModel.save_timing_to_pkl(total_run_time, train_time, time_per_iteration, self.output_dir)
-            # create result dictionary and save it
-            results_dict = {"train_time": train_time,
-                            "total_steps": self.total_steps}
-            TrainModel.save_results_to_dict(results_dict, self.output_dir)
+                TrainModel.save_timing_to_pkl(self, train_time, time_per_iteration)
+                # create result dictionary and save it
+                #results_dict = {"train_time": train_time,
+                #                "total_steps": self.total_steps}
+                #TrainModel.save_results_to_dict(results_dict, self.output_dir)
             # print some diagnostics
             print("train_losses:",train_losses)
             print("val_losses:",val_losses) 
-- 
GitLab