diff --git a/video_prediction_tools/main_scripts/main_train_models.py b/video_prediction_tools/main_scripts/main_train_models.py index 069e8ec9ea1812d8254454b33f670209d65024d9..c02378a702f5d807210cbef890b507fc99a8c7c7 100644 --- a/video_prediction_tools/main_scripts/main_train_models.py +++ b/video_prediction_tools/main_scripts/main_train_models.py @@ -136,23 +136,22 @@ class TrainModel(object): def setup_dataset(self): """ - Setup train and val dataset instance with the corresponding data split configuration + Setup train and val dataset instance with the corresponding data split configuration. + Simultaneously, sequence_length is attached to the hyperparameter dictionary. """ VideoDataset = datasets.get_dataset_class(self.dataset) self.train_dataset = VideoDataset(input_dir=self.input_dir,mode='train',datasplit_config=self.datasplit_dict) self.val_dataset = VideoDataset(input_dir=self.input_dir, mode='val',datasplit_config=self.datasplit_dict) - #self.variable_scope = tf.get_variable_scope() - #self.variable_scope.set_use_resource(True) - + + self.model_hparams_dict_load.update({"sequence_length": self.train_dataset.sequence_length}) def setup_model(self): """ Set up model instance for the given model names """ VideoPredictionModel = models.get_model_class(self.model) - self.video_model = VideoPredictionModel( - hparams_dict=self.model_hparams_dict_load, - ) + self.video_model = VideoPredictionModel(hparams_dict=self.model_hparams_dict_load) + def setup_graph(self): """ build model graph