diff --git a/video_prediction_tools/hparams/era5/convLSTM/model_hparams_template.json b/video_prediction_tools/hparams/era5/convLSTM/model_hparams_template.json index 878c29a0553ddb74a563299a7d3ec5683469194c..05a3d68f0b5cac92d3574d5bc4585370aef42973 100644 --- a/video_prediction_tools/hparams/era5/convLSTM/model_hparams_template.json +++ b/video_prediction_tools/hparams/era5/convLSTM/model_hparams_template.json @@ -3,9 +3,9 @@ "batch_size": 4, "lr": 0.001, "max_epochs":20, - "context_frames":10, + "context_frames":12, "loss_fun":"rmse", - "shuffle_on_val":false + "shuffle_on_val":true } diff --git a/video_prediction_tools/hparams/era5/mcnet/model_hparams_template.json b/video_prediction_tools/hparams/era5/mcnet/model_hparams_template.json index 0b3788d726fc91e3d5c1aec98166259a2e0012e9..bc5f8983a5aa6b0b2ba3d560bc4c2391995794a4 100644 --- a/video_prediction_tools/hparams/era5/mcnet/model_hparams_template.json +++ b/video_prediction_tools/hparams/era5/mcnet/model_hparams_template.json @@ -2,8 +2,8 @@ { "batch_size": 10, "lr": 0.001, - "max_epochs":2, - "context_frames":10 + "max_epochs": 2, + "context_frames": 12 } diff --git a/video_prediction_tools/hparams/era5/ours_vae_l1/model_hparams_template.json b/video_prediction_tools/hparams/era5/ours_vae_l1/model_hparams_template.json index 8e96727e95f761efc170365abd4e8af89696c168..770f9ff516a630ff031b94bb2c8a2b41c1686eec 100644 --- a/video_prediction_tools/hparams/era5/ours_vae_l1/model_hparams_template.json +++ b/video_prediction_tools/hparams/era5/ours_vae_l1/model_hparams_template.json @@ -11,5 +11,5 @@ "state_weight": 0.0, "nz": 32, "max_epochs":2, - "context_frames":10 + "context_frames":12 } diff --git a/video_prediction_tools/hparams/era5/savp/model_hparams_template.json b/video_prediction_tools/hparams/era5/savp/model_hparams_template.json index d182658a2161a0405d9fb92d9677fc34bd39251f..6bff9c260a5acdce69cbe57c8fb9488162cdb5f2 100644 --- a/video_prediction_tools/hparams/era5/savp/model_hparams_template.json +++ 
b/video_prediction_tools/hparams/era5/savp/model_hparams_template.json @@ -13,7 +13,7 @@ "state_weight": 0.0, "nz": 16, "max_epochs":2, - "context_frames":10 + "context_frames": 12 } diff --git a/video_prediction_tools/hparams/era5/vae/model_hparams_template.json b/video_prediction_tools/hparams/era5/vae/model_hparams_template.json index 2dcecd346b9b4adc4f3179020d0ee83b8512c6a0..1306627e24bec0888600fb88fcaa937e5f01dbd7 100644 --- a/video_prediction_tools/hparams/era5/vae/model_hparams_template.json +++ b/video_prediction_tools/hparams/era5/vae/model_hparams_template.json @@ -4,10 +4,10 @@ "lr": 0.001, "nz":16, "max_epochs":2, - "context_frames":10, + "context_frames":12, "weight_recon":1, "loss_fun": "rmse", - "shuffle_on_val": false + "shuffle_on_val": true } diff --git a/video_prediction_tools/main_scripts/main_train_models.py b/video_prediction_tools/main_scripts/main_train_models.py index 751506d4b65df80cc4489f05054cedb9c0799ba8..30a3286d543ac8ec5f165e9266745e4c2d9732f8 100644 --- a/video_prediction_tools/main_scripts/main_train_models.py +++ b/video_prediction_tools/main_scripts/main_train_models.py @@ -139,11 +139,13 @@ class TrainModel(object): Setup train and val dataset instance with the corresponding data split configuration. Simultaneously, sequence_length is attached to the hyperparameter dictionary. 
""" VideoDataset = datasets.get_dataset_class(self.dataset) - self.train_dataset = VideoDataset(input_dir=self.input_dir,mode='train',datasplit_config=self.datasplit_dict) - self.val_dataset = VideoDataset(input_dir=self.input_dir, mode='val',datasplit_config=self.datasplit_dict) - - self.model_hparams_dict_load.update({"sequence_length": self.train_dataset.sequence_length}) + self.train_dataset = VideoDataset(input_dir=self.input_dir, mode='train', datasplit_config=self.datasplit_dict, +
 hparams_dict_config=self.model_hparams_dict) + self.val_dataset = VideoDataset(input_dir=self.input_dir, mode='val', datasplit_config=self.datasplit_dict, + hparams_dict_config=self.model_hparams_dict) + # ML/BG 2021-06-15: Is the following needed? + # self.model_hparams_dict_load.update({"sequence_length": self.train_dataset.sequence_length}) def setup_model(self): """