diff --git a/video_prediction_tools/main_scripts/main_train_models.py b/video_prediction_tools/main_scripts/main_train_models.py index 30a3286d543ac8ec5f165e9266745e4c2d9732f8..1d3f35dac16c5fde4ba03cad033891fef952e93e 100644 --- a/video_prediction_tools/main_scripts/main_train_models.py +++ b/video_prediction_tools/main_scripts/main_train_models.py @@ -139,12 +139,13 @@ class TrainModel(object): Setup train and val dataset instance with the corresponding data split configuration. Simultaneously, sequence_length is attached to the hyperparameter dictionary. """ + VideoDataset = datasets.get_dataset_class(self.dataset) self.train_dataset = VideoDataset(input_dir=self.input_dir, mode='train', datasplit_config=self.datasplit_dict, hparams_dict_config=self.model_hparams_dict) self.val_dataset = VideoDataset(input_dir=self.input_dir, mode='val', datasplit_config=self.datasplit_dict, hparams_dict_config=self.model_hparams_dict) # ML/BG 2021-06-15: Is the following needed? - # self.model_hparams_dict_load.update({"sequence_length": self.train_dataset.sequence_length}) + self.model_hparams_dict_load.update({"sequence_length": self.train_dataset.sequence_length}) def setup_model(self): """