diff --git a/Zam347_scripts/generate_era5.sh b/Zam347_scripts/generate_era5.sh
index 72046611bc0e35aa297b73266aa9c2e89c0101b8..1275aa6503f5f08c2c47095b41b688a122e7de12 100755
--- a/Zam347_scripts/generate_era5.sh
+++ b/Zam347_scripts/generate_era5.sh
@@ -3,7 +3,7 @@
 
 python -u ../scripts/generate_transfer_learning_finetune.py \
 --input_dir /home/${USER}/preprocessedData/era5-Y2017M01to02-128x160-74d00N71d00E-T_MSL_gph500/tfrecords \
---dataset_hparams sequence_length=20 --checkpoint /home/${USER}/models/era5-Y2017M01to02-128x160-74d00N71d00E-T_MSL_gph500/convLSTM \
+--dataset_hparams sequence_length=20 --checkpoint /home/${USER}/models/era5-Y2017M01to02-128x160-74d00N71d00E-T_MSL_gph500/mcnet \
 --mode test --results_dir /home/${USER}/results/era5-Y2017M01to02-128x160-74d00N71d00E-T_MSL_gph500 \
 --batch_size 2 --dataset era5 > generate_era5-out.out
 
diff --git a/Zam347_scripts/train_era5.sh b/Zam347_scripts/train_era5.sh
index 1f037f6fc21ac0e21a1e16ba5b6dc62438dda13a..38c398a2051b265860a6fb5acc5725e277405cab 100755
--- a/Zam347_scripts/train_era5.sh
+++ b/Zam347_scripts/train_era5.sh
@@ -2,5 +2,5 @@
 
 
 
-python ../scripts/train_dummy.py --input_dir /home/${USER}/preprocessedData/era5-Y2017M01to02-128x160-74d00N71d00E-T_MSL_gph500/tfrecords --dataset era5 --model convLSTM --model_hparams_dict ../hparams/era5/vae/model_hparams.json --output_dir /home/${USER}/models/era5-Y2017M01to02-128x160-74d00N71d00E-T_MSL_gph500/convLSTM
+python ../scripts/train_dummy.py --input_dir /home/${USER}/preprocessedData/era5-Y2017M01to02-128x160-74d00N71d00E-T_MSL_gph500/tfrecords --dataset era5 --model mcnet --model_hparams_dict ../hparams/era5/model_hparams.json --output_dir /home/${USER}/models/era5-Y2017M01to02-128x160-74d00N71d00E-T_MSL_gph500/mcnet
 #srun python scripts/train.py --input_dir data/era5 --dataset era5 --model savp --model_hparams_dict hparams/kth/ours_savp/model_hparams.json --output_dir logs/era5/ours_savp
diff --git a/hparams/era5/vae/model_hparams.json b/hparams/era5/vae/model_hparams.json
deleted file mode 100644
index 75e66a11a15fa462abbc113ef76253fb6d15eca6..0000000000000000000000000000000000000000
--- a/hparams/era5/vae/model_hparams.json
+++ /dev/null
@@ -1,8 +0,0 @@
-{
-    "batch_size": 8,
-    "lr": 0.001,
-    "nz": 16,
-    "max_steps":500
-}
-
-
diff --git a/scripts/generate_transfer_learning_finetune.py b/scripts/generate_transfer_learning_finetune.py
index 2a9245ab54e7ad72fdbf153504e2ec507d4688e2..331559f6287a4f24c1c19ee9f7f4b03309a22abf 100644
--- a/scripts/generate_transfer_learning_finetune.py
+++ b/scripts/generate_transfer_learning_finetune.py
@@ -88,7 +88,7 @@ def main():
     parser.add_argument("--gif_length", type = int, help = "default is sequence_length")
     parser.add_argument("--fps", type = int, default = 4)
 
-    parser.add_argument("--gpu_mem_frac", type = float, default = 0, help = "fraction of gpu memory to use")
+    parser.add_argument("--gpu_mem_frac", type = float, default = 0.95, help = "fraction of gpu memory to use")
     parser.add_argument("--seed", type = int, default = 7)
 
     args = parser.parse_args()
diff --git a/video_prediction/layers/BasicConvLSTMCell.py b/video_prediction/layers/BasicConvLSTMCell.py
index 6d8defc2874ba29f177e4512bfea78a5f4298518..321f6cc7e05320cf83e1173d8004429edf07ec24 100644
--- a/video_prediction/layers/BasicConvLSTMCell.py
+++ b/video_prediction/layers/BasicConvLSTMCell.py
@@ -79,9 +79,9 @@ class BasicConvLSTMCell(ConvRNNCell):
     def output_size(self):
         return self._num_units
 
-    def __call__(self, inputs, state, scope=None):
+    def __call__(self, inputs, state, scope=None, reuse=None):
         """Long short-term memory cell (LSTM)."""
-        with tf.variable_scope(scope or type(self).__name__): # "BasicLSTMCell"
+        with tf.variable_scope(scope or type(self).__name__, reuse=reuse): # "BasicLSTMCell"
             # Parameters of gates are concatenated into one multiply for efficiency.
             if self._state_is_tuple:
                 c, h = state
diff --git a/video_prediction/models/__init__.py b/video_prediction/models/__init__.py
index 4103a236ab6430d701bae28ee9b6ff6670b110fa..6d7323f3750949b0ddb411d4a98934928537bc53 100644
--- a/video_prediction/models/__init__.py
+++ b/video_prediction/models/__init__.py
@@ -9,7 +9,7 @@ from .sna_model import SNAVideoPredictionModel
 from .sv2p_model import SV2PVideoPredictionModel
 from .vanilla_vae_model import VanillaVAEVideoPredictionModel
 from .vanilla_convLSTM_model import VanillaConvLstmVideoPredictionModel
-
+from .mcnet_model import McNetVideoPredictionModel
 def get_model_class(model):
     model_mappings = {
         'ground_truth': 'GroundTruthVideoPredictionModel',
@@ -19,8 +19,10 @@ def get_model_class(model):
         'sna': 'SNAVideoPredictionModel',
         'sv2p': 'SV2PVideoPredictionModel',
         'vae': 'VanillaVAEVideoPredictionModel',
-        'convLSTM': 'VanillaConvLstmVideoPredictionModel'
-    }
+        'convLSTM': 'VanillaConvLstmVideoPredictionModel',
+        'mcnet': 'McNetVideoPredictionModel',
+
+    }
     model_class = model_mappings.get(model, model)
     model_class = globals().get(model_class)
     if model_class is None or not issubclass(model_class, BaseVideoPredictionModel):
diff --git a/video_prediction/models/vanilla_convLSTM_model.py b/video_prediction/models/vanilla_convLSTM_model.py
index 225d4a5493158ab77dfb182f7f1a45fa5156286e..e7753004348ae0ae60057a469de1e2d1421c3869 100644
--- a/video_prediction/models/vanilla_convLSTM_model.py
+++ b/video_prediction/models/vanilla_convLSTM_model.py
@@ -44,12 +44,12 @@ class VanillaConvLstmVideoPredictionModel(BaseVideoPredictionModel):
             batch_size: batch size for training.
             lr: learning rate. if decay steps is non-zero, this is the
                 learning rate for steps <= decay_step.
-            end_lr: learning rate for steps >= end_decay_step if decay_steps
-                is non-zero, ignored otherwise.
-            decay_steps: (decay_step, end_decay_step) tuple.
+
+
+
             max_steps: number of training steps.
-            beta1: momentum term of Adam.
-            beta2: momentum term of Adam.
+
+
             context_frames: the number of ground-truth frames to pass in at
                 start. Must be specified during instantiation.
             sequence_length: the number of frames in the video sequence,