From b69583390856a37ad1d7a2cf6cee07b19a0ba60a Mon Sep 17 00:00:00 2001 From: "b.gong" <b.gong@fz-juelich.de> Date: Mon, 8 Jun 2020 10:07:24 +0200 Subject: [PATCH] Solved issue #56 --- Zam347_scripts/generate_era5.sh | 2 +- Zam347_scripts/train_era5.sh | 2 +- hparams/era5/vae/model_hparams.json | 8 -------- scripts/generate_transfer_learning_finetune.py | 2 +- video_prediction/layers/BasicConvLSTMCell.py | 4 ++-- video_prediction/models/__init__.py | 8 +++++--- video_prediction/models/vanilla_convLSTM_model.py | 10 +++++----- 7 files changed, 15 insertions(+), 21 deletions(-) delete mode 100644 hparams/era5/vae/model_hparams.json diff --git a/Zam347_scripts/generate_era5.sh b/Zam347_scripts/generate_era5.sh index 72046611..1275aa65 100755 --- a/Zam347_scripts/generate_era5.sh +++ b/Zam347_scripts/generate_era5.sh @@ -3,7 +3,7 @@ python -u ../scripts/generate_transfer_learning_finetune.py \ --input_dir /home/${USER}/preprocessedData/era5-Y2017M01to02-128x160-74d00N71d00E-T_MSL_gph500/tfrecords \ ---dataset_hparams sequence_length=20 --checkpoint /home/${USER}/models/era5-Y2017M01to02-128x160-74d00N71d00E-T_MSL_gph500/convLSTM \ +--dataset_hparams sequence_length=20 --checkpoint /home/${USER}/models/era5-Y2017M01to02-128x160-74d00N71d00E-T_MSL_gph500/mcnet \ --mode test --results_dir /home/${USER}/results/era5-Y2017M01to02-128x160-74d00N71d00E-T_MSL_gph500 \ --batch_size 2 --dataset era5 > generate_era5-out.out diff --git a/Zam347_scripts/train_era5.sh b/Zam347_scripts/train_era5.sh index 1f037f6f..38c398a2 100755 --- a/Zam347_scripts/train_era5.sh +++ b/Zam347_scripts/train_era5.sh @@ -2,5 +2,5 @@ -python ../scripts/train_dummy.py --input_dir /home/${USER}/preprocessedData/era5-Y2017M01to02-128x160-74d00N71d00E-T_MSL_gph500/tfrecords --dataset era5 --model convLSTM --model_hparams_dict ../hparams/era5/vae/model_hparams.json --output_dir /home/${USER}/models/era5-Y2017M01to02-128x160-74d00N71d00E-T_MSL_gph500/convLSTM +python ../scripts/train_dummy.py --input_dir /home/${USER}/preprocessedData/era5-Y2017M01to02-128x160-74d00N71d00E-T_MSL_gph500/tfrecords --dataset era5 --model mcnet --model_hparams_dict ../hparams/era5/model_hparams.json --output_dir /home/${USER}/models/era5-Y2017M01to02-128x160-74d00N71d00E-T_MSL_gph500/mcnet #srun python scripts/train.py --input_dir data/era5 --dataset era5 --model savp --model_hparams_dict hparams/kth/ours_savp/model_hparams.json --output_dir logs/era5/ours_savp diff --git a/hparams/era5/vae/model_hparams.json b/hparams/era5/vae/model_hparams.json deleted file mode 100644 index 75e66a11..00000000 --- a/hparams/era5/vae/model_hparams.json +++ /dev/null @@ -1,8 +0,0 @@ -{ - "batch_size": 8, - "lr": 0.001, - "nz": 16, - "max_steps":500 -} - - diff --git a/scripts/generate_transfer_learning_finetune.py b/scripts/generate_transfer_learning_finetune.py index 2a9245ab..331559f6 100644 --- a/scripts/generate_transfer_learning_finetune.py +++ b/scripts/generate_transfer_learning_finetune.py @@ -88,7 +88,7 @@ def main(): parser.add_argument("--gif_length", type = int, help = "default is sequence_length") parser.add_argument("--fps", type = int, default = 4) - parser.add_argument("--gpu_mem_frac", type = float, default = 0, help = "fraction of gpu memory to use") + parser.add_argument("--gpu_mem_frac", type = float, default = 0.95, help = "fraction of gpu memory to use") parser.add_argument("--seed", type = int, default = 7) args = parser.parse_args() diff --git a/video_prediction/layers/BasicConvLSTMCell.py b/video_prediction/layers/BasicConvLSTMCell.py index 6d8defc2..321f6cc7 100644 --- a/video_prediction/layers/BasicConvLSTMCell.py +++ b/video_prediction/layers/BasicConvLSTMCell.py @@ -79,9 +79,9 @@ class BasicConvLSTMCell(ConvRNNCell): def output_size(self): return self._num_units - def __call__(self, inputs, state, scope=None): + def __call__(self, inputs, state, scope=None,reuse=None): """Long short-term memory cell (LSTM).""" - with tf.variable_scope(scope or type(self).__name__): # "BasicLSTMCell" + with tf.variable_scope(scope or type(self).__name__,reuse=reuse): # "BasicLSTMCell" # Parameters of gates are concatenated into one multiply for efficiency. if self._state_is_tuple: c, h = state diff --git a/video_prediction/models/__init__.py b/video_prediction/models/__init__.py index 4103a236..6d7323f3 100644 --- a/video_prediction/models/__init__.py +++ b/video_prediction/models/__init__.py @@ -9,7 +9,7 @@ from .sna_model import SNAVideoPredictionModel from .sv2p_model import SV2PVideoPredictionModel from .vanilla_vae_model import VanillaVAEVideoPredictionModel from .vanilla_convLSTM_model import VanillaConvLstmVideoPredictionModel - +from .mcnet_model import McNetVideoPredictionModel def get_model_class(model): model_mappings = { 'ground_truth': 'GroundTruthVideoPredictionModel', @@ -19,8 +19,10 @@ def get_model_class(model): 'sna': 'SNAVideoPredictionModel', 'sv2p': 'SV2PVideoPredictionModel', 'vae': 'VanillaVAEVideoPredictionModel', - 'convLSTM': 'VanillaConvLstmVideoPredictionModel' - } + 'convLSTM': 'VanillaConvLstmVideoPredictionModel', + 'mcnet': 'McNetVideoPredictionModel', + + } model_class = model_mappings.get(model, model) model_class = globals().get(model_class) if model_class is None or not issubclass(model_class, BaseVideoPredictionModel): diff --git a/video_prediction/models/vanilla_convLSTM_model.py b/video_prediction/models/vanilla_convLSTM_model.py index 225d4a54..e7753004 100644 --- a/video_prediction/models/vanilla_convLSTM_model.py +++ b/video_prediction/models/vanilla_convLSTM_model.py @@ -44,12 +44,12 @@ class VanillaConvLstmVideoPredictionModel(BaseVideoPredictionModel): batch_size: batch size for training. lr: learning rate. if decay steps is non-zero, this is the learning rate for steps <= decay_step. - end_lr: learning rate for steps >= end_decay_step if decay_steps - is non-zero, ignored otherwise. - decay_steps: (decay_step, end_decay_step) tuple. + + + max_steps: number of training steps. - beta1: momentum term of Adam. - beta2: momentum term of Adam. + + context_frames: the number of ground-truth frames to pass in at start. Must be specified during instantiation. sequence_length: the number of frames in the video sequence, -- GitLab