From 4416932e231b3f3bf1da170fd65b390584ffb48c Mon Sep 17 00:00:00 2001 From: gong1 <b.gong@fz-juelich.de> Date: Mon, 21 Jun 2021 11:07:37 +0200 Subject: [PATCH] Add back get_dataset_class and update the sequence_length from dataset object --- video_prediction_tools/main_scripts/main_train_models.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/video_prediction_tools/main_scripts/main_train_models.py b/video_prediction_tools/main_scripts/main_train_models.py index 30a3286d..1d3f35da 100644 --- a/video_prediction_tools/main_scripts/main_train_models.py +++ b/video_prediction_tools/main_scripts/main_train_models.py @@ -139,12 +139,13 @@ class TrainModel(object): Setup train and val dataset instance with the corresponding data split configuration. Simultaneously, sequence_length is attached to the hyperparameter dictionary. """ + VideoDataset = datasets.get_dataset_class(self.dataset) self.train_dataset = VideoDataset(input_dir=self.input_dir, mode='train', datasplit_config=self.datasplit_dict, hparams_dict_config=self.model_hparams_dict) self.val_dataset = VideoDataset(input_dir=self.input_dir, mode='val', datasplit_config=self.datasplit_dict, hparams_dict_config=self.model_hparams_dict) # ML/BG 2021-06-15: Is the following needed? - # self.model_hparams_dict_load.update({"sequence_length": self.train_dataset.sequence_length}) + self.model_hparams_dict_load.update({"sequence_length": self.train_dataset.sequence_length}) def setup_model(self): """ -- GitLab