diff --git a/video_prediction_savp/HPC_scripts/train_era5.sh b/video_prediction_savp/HPC_scripts/train_era5.sh
index f605866056f6b2d9fa179a00850468fee0c72d87..5173564faae730cda10ac3acc072fe9ed43cb7b3 100755
--- a/video_prediction_savp/HPC_scripts/train_era5.sh
+++ b/video_prediction_savp/HPC_scripts/train_era5.sh
@@ -37,7 +37,7 @@ fi
 source_dir=/p/project/deepacf/deeprain/video_prediction_shared_folder/preprocessedData/
 destination_dir=/p/project/deepacf/deeprain/video_prediction_shared_folder/models/

-# for choosing the model
+# for choosing the model: convLSTM, savp, mcnet, vae
 model=convLSTM
 model_hparams=../hparams/era5/${model}/model_hparams.json
diff --git a/video_prediction_savp/hparams/era5/mcnet/model_hparams.json b/video_prediction_savp/hparams/era5/mcnet/model_hparams.json
new file mode 100644
index 0000000000000000000000000000000000000000..c2edaad9f9ac158f6e7b8d94bb81db16d55d05e8
--- /dev/null
+++ b/video_prediction_savp/hparams/era5/mcnet/model_hparams.json
@@ -0,0 +1,12 @@
+
+{
+    "batch_size": 10,
+    "lr": 0.001,
+    "max_epochs":2,
+    "context_frames":10,
+    "sequence_length":20
+
+}
+
+
+
diff --git a/video_prediction_savp/hparams/era5/savp/model_hparams.json b/video_prediction_savp/hparams/era5/savp/model_hparams.json
new file mode 100644
index 0000000000000000000000000000000000000000..641ffb36f764f5ae720a534d7d9eef0ebad644d8
--- /dev/null
+++ b/video_prediction_savp/hparams/era5/savp/model_hparams.json
@@ -0,0 +1,18 @@
+{
+    "batch_size": 4,
+    "lr": 0.0002,
+    "beta1": 0.5,
+    "beta2": 0.999,
+    "l1_weight": 100.0,
+    "l2_weight": 0.0,
+    "kl_weight": 0.01,
+    "video_sn_vae_gan_weight": 0.1,
+    "video_sn_gan_weight": 0.1,
+    "vae_gan_feature_cdist_weight": 10.0,
+    "gan_feature_cdist_weight": 0.0,
+    "state_weight": 0.0,
+    "nz": 32,
+    "max_epochs":2
+}
+
+
diff --git a/video_prediction_savp/scripts/train_dummy.py b/video_prediction_savp/scripts/train_dummy.py
index 1fb401955c39be4807cf7747e43ed660941cb925..4e30c4ce8e65799b88defa7c331d08dd0469c079 100644
--- a/video_prediction_savp/scripts/train_dummy.py
+++ b/video_prediction_savp/scripts/train_dummy.py
@@ -199,7 +199,7 @@
     parser.add_argument("--model_hparams", type=str,
                         help="a string of comma separated list of model hyperparameters")
     parser.add_argument("--model_hparams_dict", type=str,
                         help="a json file of model hyperparameters")
-    parser.add_argument("--gpu_mem_frac", type=float, default=0, help="fraction of gpu memory to use")
+    parser.add_argument("--gpu_mem_frac", type=float, default=0.99, help="fraction of gpu memory to use")
     parser.add_argument("--seed",default=1234, type=int)
     args = parser.parse_args()
@@ -232,6 +232,7 @@
     inputs, train_handle, val_handle = make_dataset_iterator(train_dataset, val_dataset, batch_size)

     #build model graph
+    del inputs["T_start"]
     model.build_graph(inputs)

     #save all the model, data params to output dirctory
@@ -255,6 +256,7 @@
     print ("number of exmaples per epoch:",num_examples_per_epoch)
     steps_per_epoch = int(num_examples_per_epoch/batch_size)
     total_steps = steps_per_epoch * max_epochs
+    global_step = tf.train.get_or_create_global_step()
     #mock total_steps only for fast debugging
     #total_steps = 10
     print ("Total steps for training:",total_steps)
@@ -263,63 +265,77 @@
         print("parameter_count =", sess.run(parameter_count))
         sess.run(tf.global_variables_initializer())
         sess.run(tf.local_variables_initializer())
-        #model.restore(sess, args.checkpoint)
+        model.restore(sess, args.checkpoint)
         sess.graph.finalize()
-        start_step = sess.run(model.global_step)
+        #start_step = sess.run(model.global_step)
+        start_step = sess.run(global_step)
         print("start_step", start_step)
         # start at one step earlier to log everything without doing any training
         # step is relative to the start_step
         train_losses=[]
         val_losses=[]
         run_start_time = time.time()
-        for step in range(total_steps):
-            global_step = sess.run(model.global_step)
-            print ("global_step:", global_step)
+        for step in range(start_step,total_steps):
+            #global_step = sess.run(global_step)
+
+            print ("step:", step)
             val_handle_eval = sess.run(val_handle)
-
+
             #Fetch variables in the graph
-            fetches = {"global_step":model.global_step}
-            fetches["train_op"] = model.train_op
-            #fetches["latent_loss"] = model.latent_loss
-            fetches["total_loss"] = model.total_loss
-            #fetch the specific loss function only for mcnet
-            if model.__class__.__name__ == "McNetVideoPredictionModel":
-                fetches["L_p"] = model.L_p
-                fetches["L_gdl"] = model.L_gdl
-                fetches["L_GAN"] =model.L_GAN
-
-            if model.__class__.__name__ == "SAVP":
-                #todo
-                pass
+            fetches = {"train_op": model.train_op}
+            #fetches["latent_loss"] = model.latent_loss
+            fetches["summary"] = model.summary_op
-            fetches["summary"] = model.summary_op
-            results = sess.run(fetches)
-            train_losses.append(results["total_loss"])
-            #Fetch losses for validation data
-            val_fetches = {}
-            #val_fetches["latent_loss"] = model.latent_loss
-            val_fetches["total_loss"] = model.total_loss
+            if model.__class__.__name__ == "McNetVideoPredictionModel" or model.__class__.__name__ == "VanillaConvLstmVideoPredictionModel":
+                fetches["global_step"] = model.global_step
+                fetches["total_loss"] = model.total_loss
+                #fetch the specific loss function only for mcnet
+                if model.__class__.__name__ == "McNetVideoPredictionModel":
+                    fetches["L_p"] = model.L_p
+                    fetches["L_gdl"] = model.L_gdl
+                    fetches["L_GAN"] =model.L_GAN
+                results = sess.run(fetches)
+                train_losses.append(results["total_loss"])
+                #Fetch losses for validation data
+                val_fetches = {}
+                #val_fetches["latent_loss"] = model.latent_loss
+                val_fetches["total_loss"] = model.total_loss
+
+
+            if model.__class__.__name__ == "SAVPVideoPredictionModel":
+                fetches['d_loss'] = model.d_loss
+                fetches['g_loss'] = model.g_loss
+                fetches['d_losses'] = model.d_losses
+                fetches['g_losses'] = model.g_losses
+                results = sess.run(fetches)
+                train_losses.append(results["g_losses"])
+                val_fetches = {}
+                #val_fetches["latent_loss"] = model.latent_loss
+                #For SAVP the total loss is the generator losses
+                val_fetches["total_loss"] = model.g_losses
+
             val_fetches["summary"] = model.summary_op
             val_results = sess.run(val_fetches,feed_dict={train_handle: val_handle_eval})
             val_losses.append(val_results["total_loss"])
-
+
             summary_writer.add_summary(results["summary"])
             summary_writer.add_summary(val_results["summary"])
             summary_writer.flush()
-
+
             # global_step will have the correct step count if we resume from a checkpoint
             # global step is read before it's incemented
-            train_epoch = global_step/steps_per_epoch
-            print("progress global step %d epoch %0.1f" % (global_step + 1, train_epoch))
-
+            train_epoch = step/steps_per_epoch
+            print("progress global step %d epoch %0.1f" % (step + 1, train_epoch))
             if model.__class__.__name__ == "McNetVideoPredictionModel":
-                print("Total_loss:{}; L_p_loss:{}; L_gdl:{}; L_GAN: {}".format(results["total_loss"],results["L_p"],results["L_gdl"],results["L_GAN"]))
+                print("Total_loss:{}; L_p_loss:{}; L_gdl:{}; L_GAN: {}".format(results["total_loss"],results["L_p"],results["L_gdl"],results["L_GAN"]))
             elif model.__class__.__name__ == "VanillaConvLstmVideoPredictionModel":
                 print ("Total_loss:{}".format(results["total_loss"]))
+            elif model.__class__.__name__ == "SAVPVideoPredictionModel":
+                print("Total_loss/g_losses:{}; d_losses:{}; g_loss:{}; d_loss: {}".format(results["g_losses"],results["d_losses"],results["g_loss"],results["d_loss"]))
             else:
                 print ("The model name does not exist")
-
+
             #print("saving model to", args.output_dir)
             saver.save(sess, os.path.join(args.output_dir, "model"), global_step=step)#

         train_time = time.time() - run_start_time
diff --git a/video_prediction_savp/video_prediction/models/base_model.py b/video_prediction_savp/video_prediction/models/base_model.py
index 0ebe228fcc9c90addf610bed44bb46f090c7e514..846621d8ca1e235c39618951be86fe184a2d974d 100644
--- a/video_prediction_savp/video_prediction/models/base_model.py
+++ b/video_prediction_savp/video_prediction/models/base_model.py
@@ -366,7 +366,7 @@ class VideoPredictionModel(BaseVideoPredictionModel):
             end_lr=0.0,
             decay_steps=(200000, 300000),
             lr_boundaries=(0,),
-            max_steps=350000,
+            max_epochs=35,
             beta1=0.9,
             beta2=0.999,
             context_frames=-1,
diff --git a/video_prediction_savp/video_prediction/models/mcnet_model.py b/video_prediction_savp/video_prediction/models/mcnet_model.py
index 725ce4f46a301b6aa07f3d50ef811584d5b502db..7a376cb7b2ddb4f46b3ad67a6b2cf7e866823427 100644
--- a/video_prediction_savp/video_prediction/models/mcnet_model.py
+++ b/video_prediction_savp/video_prediction/models/mcnet_model.py
@@ -72,7 +72,7 @@ class McNetVideoPredictionModel(BaseVideoPredictionModel):
         hparams = dict(
             batch_size=16,
             lr=0.001,
-            max_steps=350000,
+            max_epochs=350000,
             context_frames = 10,
             sequence_length = 20,
             nz = 16,
@@ -96,7 +96,8 @@ class McNetVideoPredictionModel(BaseVideoPredictionModel):

         self.is_train = True

-        self.global_step = tf.Variable(0, name='global_step', trainable=False)
+        #self.global_step = tf.Variable(0, name='global_step', trainable=False)
+        self.global_step = tf.train.get_or_create_global_step()
         original_global_variables = tf.global_variables()
         # self.xt = tf.placeholder(tf.float32, self.xt_shape, name='xt')
diff --git a/video_prediction_savp/video_prediction/models/savp_model.py b/video_prediction_savp/video_prediction/models/savp_model.py
index ca8acd3f32a5ea1772c9fbf36003149acfdcb950..c510d050c89908d0e06fe0f1a66e355e61c90530 100644
--- a/video_prediction_savp/video_prediction/models/savp_model.py
+++ b/video_prediction_savp/video_prediction/models/savp_model.py
@@ -688,6 +688,7 @@ class SAVPCell(tf.nn.rnn_cell.RNNCell):

 def generator_given_z_fn(inputs, mode, hparams):
     # all the inputs needs to have the same length for unrolling the rnn
+    print("inputs.items",inputs.items())
     inputs = {name: tf_utils.maybe_pad_or_slice(input, hparams.sequence_length - 1)
               for name, input in inputs.items()}
     cell = SAVPCell(inputs, mode, hparams)
diff --git a/video_prediction_savp/video_prediction/models/vanilla_convLSTM_model.py b/video_prediction_savp/video_prediction/models/vanilla_convLSTM_model.py
index 7e3fec28dc28c78b8203e1924f17489af8f5075e..6ca386fcda740b7a3da1a16d0ad84dcd08fe653a 100644
--- a/video_prediction_savp/video_prediction/models/vanilla_convLSTM_model.py
+++ b/video_prediction_savp/video_prediction/models/vanilla_convLSTM_model.py
@@ -60,8 +60,8 @@ class VanillaConvLstmVideoPredictionModel(BaseVideoPredictionModel):

     def build_graph(self, x):
         self.x = x["images"]
-
-        self.global_step = tf.Variable(0, name = 'global_step', trainable = False)
+        #self.global_step = tf.Variable(0, name = 'global_step', trainable = False)
+        self.global_step = tf.train.get_or_create_global_step()
         original_global_variables = tf.global_variables()
         # ARCHITECTURE
         self.convLSTM_network()
diff --git a/video_prediction_savp/video_prediction/models/vanilla_vae_model.py b/video_prediction_savp/video_prediction/models/vanilla_vae_model.py
index eec5598305044226280080d630313487c7d847a4..81c556cea556aa4a7415f33ae4de817023c89d9b 100644
--- a/video_prediction_savp/video_prediction/models/vanilla_vae_model.py
+++ b/video_prediction_savp/video_prediction/models/vanilla_vae_model.py
@@ -63,7 +63,7 @@ class VanillaVAEVideoPredictionModel(BaseVideoPredictionModel):
             end_lr=0.0,
             decay_steps=(200000, 300000),
             lr_boundaries=(0,),
-            max_steps=350000,
+            max_epochs=35,
             nz=10,
             context_frames=-1,
             sequence_length=-1,
@@ -71,42 +71,16 @@ class VanillaVAEVideoPredictionModel(BaseVideoPredictionModel):
         )
         return dict(itertools.chain(default_hparams.items(), hparams.items()))

-    def build_graph(self,x):
-
-
-
-
-
-
-        tf.set_random_seed(12345)
+    def build_graph(self,x):
         self.x = x["images"]
-
-
-        self.global_step = tf.Variable(0, name = 'global_step', trainable = False)
+        #self.global_step = tf.Variable(0, name = 'global_step', trainable = False)
+        self.global_step = tf.train.get_or_create_global_step()
         original_global_variables = tf.global_variables()
-        self.increment_global_step = tf.assign_add(self.global_step, 1, name = 'increment_global_step')
-
         self.x_hat, self.z_log_sigma_sq, self.z_mu = self.vae_arc_all()
-
-
-
-
-
-
-
-
-
-
-
         self.recon_loss = tf.reduce_mean(tf.square(self.x[:, 1:, :, :, 0] - self.x_hat[:, :-1, :, :, 0]))
-
-
-
-
-
         latent_loss = -0.5 * tf.reduce_sum(
             1 + self.z_log_sigma_sq - tf.square(self.z_mu) -
-            tf.exp(self.z_log_sigma_sq), axis = 1)
+            tf.exp(self.z_log_sigma_sq), axis=1)
         self.latent_loss = tf.reduce_mean(latent_loss)
         self.total_loss = self.recon_loss + self.latent_loss
         self.train_op = tf.train.AdamOptimizer(
@@ -125,51 +99,33 @@ class VanillaVAEVideoPredictionModel(BaseVideoPredictionModel):
         self.loss_summary = tf.summary.scalar("total_loss", self.latent_loss)
         self.summary_op = tf.summary.merge_all()
-
-
         self.outputs = {}
         self.outputs["gen_images"] = self.x_hat
         global_variables = [var for var in tf.global_variables() if var not in original_global_variables]
         self.saveable_variables = [self.global_step] + global_variables
-        return
+        return None

     @staticmethod
     def vae_arc3(x,l_name=0,nz=16):
         seq_name = "sq_" + str(l_name) + "_"
-
-        conv1 = ld.conv_layer(x, 3, 2, 8, seq_name + "encode_1")
-
+        conv1 = ld.conv_layer(x, 3, 2, 8, seq_name + "encode_1")
         conv2 = ld.conv_layer(conv1, 3, 1, 8, seq_name + "encode_2")
-
-
         conv3 = ld.conv_layer(conv2, 3, 2, 8, seq_name + "encode_3")
-
-
         conv4 = tf.layers.Flatten()(conv3)
-
         conv3_shape = conv3.get_shape().as_list()
-
-
         z_mu = ld.fc_layer(conv4, hiddens = nz, idx = seq_name + "enc_fc4_m")
         z_log_sigma_sq = ld.fc_layer(conv4, hiddens = nz, idx = seq_name + "enc_fc4_m"'enc_fc4_sigma')
         eps = tf.random_normal(shape = tf.shape(z_log_sigma_sq), mean = 0, stddev = 1, dtype = tf.float32)
-        z = z_mu + tf.sqrt(tf.exp(z_log_sigma_sq)) * eps
-
-        z2 = ld.fc_layer(z, hiddens = conv3_shape[1] * conv3_shape[2] * conv3_shape[3], idx = seq_name + "decode_fc1")
-
-
+        z = z_mu + tf.sqrt(tf.exp(z_log_sigma_sq)) * eps
+        z2 = ld.fc_layer(z, hiddens = conv3_shape[1] * conv3_shape[2] * conv3_shape[3], idx = seq_name + "decode_fc1")
         z3 = tf.reshape(z2, [-1, conv3_shape[1], conv3_shape[2], conv3_shape[3]])
-
         conv5 = ld.transpose_conv_layer(z3, 3, 2, 8, seq_name + "decode_5")
-
         conv6 = ld.transpose_conv_layer(conv5, 3, 1, 8, seq_name + "decode_6")
-
-
         x_hat = ld.transpose_conv_layer(conv6, 3, 2, 3, seq_name + "decode_8")
         return x_hat, z_mu, z_log_sigma_sq, z
@@ -186,6 +142,5 @@ class VanillaVAEVideoPredictionModel(BaseVideoPredictionModel):
         x_hat = tf.stack(X, axis = 1)
         z_log_sigma_sq_all = tf.stack(z_log_sigma_sq_all, axis = 1)
         z_mu_all = tf.stack(z_mu_all, axis = 1)
-
-
+
         return x_hat, z_log_sigma_sq_all, z_mu_all
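
Note on the max_steps -> max_epochs switch above: scripts/train_dummy.py derives the number of optimizer steps from the dataset size, the batch size and max_epochs, and the loop resumes from the restored global step. Below is a minimal standalone sketch of that bookkeeping, not code from the repository; the dataset size (8000 samples) is an assumption, while batch_size=10 and max_epochs=2 are taken from the new mcnet model_hparams.json.

# Minimal sketch (not part of the diff): how max_epochs translates into
# training-loop steps, mirroring the arithmetic in scripts/train_dummy.py.
def compute_training_steps(num_examples_per_epoch, batch_size, max_epochs, start_step=0):
    """Return (steps_per_epoch, total_steps, remaining_steps) for the training loop."""
    steps_per_epoch = int(num_examples_per_epoch / batch_size)
    total_steps = steps_per_epoch * max_epochs
    # When resuming from a checkpoint the loop runs range(start_step, total_steps),
    # so only the remaining steps are executed.
    remaining_steps = max(total_steps - start_step, 0)
    return steps_per_epoch, total_steps, remaining_steps

# Example with an assumed dataset size of 8000 preprocessed samples:
# prints (800, 1600, 1600)
print(compute_training_steps(num_examples_per_epoch=8000, batch_size=10, max_epochs=2))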