From 1de49cd46e3097ae1e117dbe6b002f814f04885f Mon Sep 17 00:00:00 2001
From: Michael <m.langguth@fz-juelich.de>
Date: Thu, 20 May 2021 16:26:42 +0200
Subject: [PATCH] Update hyperparameter dictionary with sequnece_length from
 dataset-object in main_train_models.py.

---
 .../main_scripts/main_train_models.py               | 13 ++++++-------
 1 file changed, 6 insertions(+), 7 deletions(-)

diff --git a/video_prediction_tools/main_scripts/main_train_models.py b/video_prediction_tools/main_scripts/main_train_models.py
index 069e8ec9..c02378a7 100644
--- a/video_prediction_tools/main_scripts/main_train_models.py
+++ b/video_prediction_tools/main_scripts/main_train_models.py
@@ -136,23 +136,22 @@ class TrainModel(object):
                 
     def setup_dataset(self):
         """
-        Setup train and val dataset instance with the corresponding data split configuration
+        Setup train and val dataset instance with the corresponding data split configuration.
+        Simultaneously, sequence_length is attached to the hyperparameter dictionary.
         """
         VideoDataset = datasets.get_dataset_class(self.dataset)
         self.train_dataset = VideoDataset(input_dir=self.input_dir,mode='train',datasplit_config=self.datasplit_dict)
         self.val_dataset = VideoDataset(input_dir=self.input_dir, mode='val',datasplit_config=self.datasplit_dict)
-        #self.variable_scope = tf.get_variable_scope()
-        #self.variable_scope.set_use_resource(True)
-      
+
+        self.model_hparams_dict_load.update({"sequence_length": self.train_dataset.sequence_length})
 
     def setup_model(self):
         """
         Set up model instance for the given model names
         """
         VideoPredictionModel = models.get_model_class(self.model)
-        self.video_model = VideoPredictionModel(
-                                    hparams_dict=self.model_hparams_dict_load,
-                                       )
+        self.video_model = VideoPredictionModel(hparams_dict=self.model_hparams_dict_load)
+
     def setup_graph(self):
         """
         build model graph
-- 
GitLab