From 06e7a80e096b1c9657ce2fe315f12b91ca09d3f7 Mon Sep 17 00:00:00 2001
From: "b.gong" <b.gong@fz-juelich.de>
Date: Tue, 2 Jun 2020 19:27:52 +0200
Subject: [PATCH] Solve global_step issue

Point the Zam347 training and generation scripts at a convLSTM model and
checkpoint directory instead of the vae one, raise the learning rate to
0.001 and max_steps to 500 in the shared hparams file, and register the
convLSTM predictions in self.outputs["gen_images"] so downstream scripts
can fetch the generated frames.

---
 Zam347_scripts/generate_era5.sh                   | 2 +-
 Zam347_scripts/train_era5.sh                      | 2 +-
 hparams/era5/vae/model_hparams.json               | 4 ++--
 video_prediction/models/vanilla_convLSTM_model.py | 3 ++-
 4 files changed, 6 insertions(+), 5 deletions(-)

diff --git a/Zam347_scripts/generate_era5.sh b/Zam347_scripts/generate_era5.sh
index 507ccf9c..72046611 100755
--- a/Zam347_scripts/generate_era5.sh
+++ b/Zam347_scripts/generate_era5.sh
@@ -3,7 +3,7 @@
 
 python -u ../scripts/generate_transfer_learning_finetune.py \
 --input_dir /home/${USER}/preprocessedData/era5-Y2017M01to02-128x160-74d00N71d00E-T_MSL_gph500/tfrecords  \
---dataset_hparams sequence_length=20 --checkpoint  /home/${USER}/models/era5-Y2017M01to02-128x160-74d00N71d00E-T_MSL_gph500/vae \
+--dataset_hparams sequence_length=20 --checkpoint  /home/${USER}/models/era5-Y2017M01to02-128x160-74d00N71d00E-T_MSL_gph500/convLSTM \
 --mode test --results_dir /home/${USER}/results/era5-Y2017M01to02-128x160-74d00N71d00E-T_MSL_gph500 \
 --batch_size 2 --dataset era5   > generate_era5-out.out
 
diff --git a/Zam347_scripts/train_era5.sh b/Zam347_scripts/train_era5.sh
index b93c7bea..1f037f6f 100755
--- a/Zam347_scripts/train_era5.sh
+++ b/Zam347_scripts/train_era5.sh
@@ -2,5 +2,5 @@
 
 
 
-python ../scripts/train_dummy.py --input_dir  /home/${USER}/preprocessedData/era5-Y2017M01to02-128x160-74d00N71d00E-T_MSL_gph500/tfrecords --dataset era5  --model vae --model_hparams_dict ../hparams/era5/vae/model_hparams.json --output_dir /home/${USER}/models/era5-Y2017M01to02-128x160-74d00N71d00E-T_MSL_gph500/vae 
+python ../scripts/train_dummy.py --input_dir  /home/${USER}/preprocessedData/era5-Y2017M01to02-128x160-74d00N71d00E-T_MSL_gph500/tfrecords --dataset era5  --model convLSTM  --model_hparams_dict ../hparams/era5/vae/model_hparams.json --output_dir /home/${USER}/models/era5-Y2017M01to02-128x160-74d00N71d00E-T_MSL_gph500/convLSTM 
 #srun  python scripts/train.py --input_dir data/era5 --dataset era5  --model savp --model_hparams_dict hparams/kth/ours_savp/model_hparams.json --output_dir logs/era5/ours_savp
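
For context, the --model convLSTM flag above has to resolve to the class
touched further down (VanillaConvLstmVideoPredictionModel). The actual
lookup lives in video_prediction/models and is not part of this patch; the
standalone sketch below only illustrates the assumed name-to-class pattern
and uses stand-in classes.

    # Standalone illustration of the assumed --model name -> class lookup.
    # Both classes are stand-ins; the real registry may differ.
    class VanillaVAEVideoPredictionModel(object):
        pass

    class VanillaConvLstmVideoPredictionModel(object):
        pass

    MODEL_CLASSES = {
        "vae": VanillaVAEVideoPredictionModel,
        "convLSTM": VanillaConvLstmVideoPredictionModel,
    }

    def get_model_class(name):
        """Return the model class registered under the --model value."""
        try:
            return MODEL_CLASSES[name]
        except KeyError:
            raise ValueError("unknown --model value: %s" % name)

    print(get_model_class("convLSTM").__name__)
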
diff --git a/hparams/era5/vae/model_hparams.json b/hparams/era5/vae/model_hparams.json
index 2e940614..75e66a11 100644
--- a/hparams/era5/vae/model_hparams.json
+++ b/hparams/era5/vae/model_hparams.json
@@ -1,8 +1,8 @@
 {
     "batch_size": 8,
-    "lr": 0.0002,
+    "lr": 0.001,
     "nz": 16,
-    "max_steps":20
+    "max_steps":500
 }
 
 
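The two values changed above only take effect through whatever parses
model_hparams.json. A minimal sketch, assuming the JSON is merged into a
tf.contrib.training.HParams object as in SAVP-style codebases (the parsing
code itself is untouched by this patch); the defaults shown are simply the
old values from this file:

    # Minimal sketch (assumption): model_hparams.json overrides HParams defaults.
    import json
    import tensorflow as tf  # TF 1.x, matching the tf.train/tf.summary API used elsewhere in this patch

    defaults = tf.contrib.training.HParams(batch_size=8, lr=0.0002, nz=16, max_steps=20)
    with open("hparams/era5/vae/model_hparams.json") as f:  # run from the repo root
        defaults.override_from_dict(json.load(f))
    print(defaults.lr, defaults.max_steps)  # 0.001 500 after this change
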
diff --git a/video_prediction/models/vanilla_convLSTM_model.py b/video_prediction/models/vanilla_convLSTM_model.py
index 6cb07df7..8cd2ad3f 100644
--- a/video_prediction/models/vanilla_convLSTM_model.py
+++ b/video_prediction/models/vanilla_convLSTM_model.py
@@ -88,7 +88,8 @@ class VanillaConvLstmVideoPredictionModel(BaseVideoPredictionModel):
 
         self.train_op = tf.train.AdamOptimizer(
             learning_rate = self.learning_rate).minimize(self.total_loss, global_step = self.global_step)
-
+        self.outputs = {}
+        self.outputs["gen_images"] = self.x_hat
         # Summary op
         self.loss_summary = tf.summary.scalar("recon_loss", self.context_frames_loss)
         self.loss_summary = tf.summary.scalar("latent_loss", self.predict_frames_loss)
-- 
GitLab