diff --git a/Zam347_scripts/generate_era5.sh b/Zam347_scripts/generate_era5.sh index 507ccf9cb2c30b50a5cb768125ed3e472b989a95..72046611bc0e35aa297b73266aa9c2e89c0101b8 100755 --- a/Zam347_scripts/generate_era5.sh +++ b/Zam347_scripts/generate_era5.sh @@ -3,7 +3,7 @@ python -u ../scripts/generate_transfer_learning_finetune.py \ --input_dir /home/${USER}/preprocessedData/era5-Y2017M01to02-128x160-74d00N71d00E-T_MSL_gph500/tfrecords \ ---dataset_hparams sequence_length=20 --checkpoint /home/${USER}/models/era5-Y2017M01to02-128x160-74d00N71d00E-T_MSL_gph500/vae \ +--dataset_hparams sequence_length=20 --checkpoint /home/${USER}/models/era5-Y2017M01to02-128x160-74d00N71d00E-T_MSL_gph500/convLSTM \ --mode test --results_dir /home/${USER}/results/era5-Y2017M01to02-128x160-74d00N71d00E-T_MSL_gph500 \ --batch_size 2 --dataset era5 > generate_era5-out.out diff --git a/Zam347_scripts/train_era5.sh b/Zam347_scripts/train_era5.sh index b93c7bea814d41b2255f1201b430c37a2022db4e..1f037f6fc21ac0e21a1e16ba5b6dc62438dda13a 100755 --- a/Zam347_scripts/train_era5.sh +++ b/Zam347_scripts/train_era5.sh @@ -2,5 +2,5 @@ -python ../scripts/train_dummy.py --input_dir /home/${USER}/preprocessedData/era5-Y2017M01to02-128x160-74d00N71d00E-T_MSL_gph500/tfrecords --dataset era5 --model vae --model_hparams_dict ../hparams/era5/vae/model_hparams.json --output_dir /home/${USER}/models/era5-Y2017M01to02-128x160-74d00N71d00E-T_MSL_gph500/vae +python ../scripts/train_dummy.py --input_dir /home/${USER}/preprocessedData/era5-Y2017M01to02-128x160-74d00N71d00E-T_MSL_gph500/tfrecords --dataset era5 --model convLSTM --model_hparams_dict ../hparams/era5/vae/model_hparams.json --output_dir /home/${USER}/models/era5-Y2017M01to02-128x160-74d00N71d00E-T_MSL_gph500/convLSTM #srun python scripts/train.py --input_dir data/era5 --dataset era5 --model savp --model_hparams_dict hparams/kth/ours_savp/model_hparams.json --output_dir logs/era5/ours_savp diff --git a/hparams/era5/vae/model_hparams.json b/hparams/era5/vae/model_hparams.json index 2e9406148e140054ced5e0c4311f3885aa47f728..75e66a11a15fa462abbc113ef76253fb6d15eca6 100644 --- a/hparams/era5/vae/model_hparams.json +++ b/hparams/era5/vae/model_hparams.json @@ -1,8 +1,8 @@ { "batch_size": 8, - "lr": 0.0002, + "lr": 0.001, "nz": 16, - "max_steps":20 + "max_steps":500 } diff --git a/video_prediction/models/vanilla_convLSTM_model.py b/video_prediction/models/vanilla_convLSTM_model.py index 6cb07df7f7cb72fa0943299adbc18a7641636521..8cd2ad3f2b99e9a88c9471db2c0dc6f4ccb89913 100644 --- a/video_prediction/models/vanilla_convLSTM_model.py +++ b/video_prediction/models/vanilla_convLSTM_model.py @@ -88,7 +88,8 @@ class VanillaConvLstmVideoPredictionModel(BaseVideoPredictionModel): self.train_op = tf.train.AdamOptimizer( learning_rate = self.learning_rate).minimize(self.total_loss, global_step = self.global_step) - + self.outputs = {} + self.outputs["gen_images"] = self.x_hat # Summary op self.loss_summary = tf.summary.scalar("recon_loss", self.context_frames_loss) self.loss_summary = tf.summary.scalar("latent_loss", self.predict_frames_loss)