From dd78d6803917d044d138f9247bb9341f81b07a44 Mon Sep 17 00:00:00 2001 From: gong1 <b.gong@fz-juelich.de> Date: Mon, 10 Aug 2020 14:56:49 +0200 Subject: [PATCH] update vae model --- video_prediction_savp/scripts/train_dummy.py | 9 +++++++-- .../models/vanilla_vae_model.py | 17 ++++++++--------- 2 files changed, 15 insertions(+), 11 deletions(-) diff --git a/video_prediction_savp/scripts/train_dummy.py b/video_prediction_savp/scripts/train_dummy.py index 468f929c..d4db02c3 100644 --- a/video_prediction_savp/scripts/train_dummy.py +++ b/video_prediction_savp/scripts/train_dummy.py @@ -287,14 +287,17 @@ def main(): #fetches["latent_loss"] = model.latent_loss fetches["summary"] = model.summary_op - if model.__class__.__name__ == "McNetVideoPredictionModel" or model.__class__.__name__ == "VanillaConvLstmVideoPredictionModel": + if model.__class__.__name__ == "McNetVideoPredictionModel" or model.__class__.__name__ == "VanillaConvLstmVideoPredictionModel" or model.__class__.__name__ == "VanillaVAEVideoPredictionModel": fetches["global_step"] = model.global_step fetches["total_loss"] = model.total_loss #fetch the specific loss function only for mcnet if model.__class__.__name__ == "McNetVideoPredictionModel": fetches["L_p"] = model.L_p fetches["L_gdl"] = model.L_gdl - fetches["L_GAN"] =model.L_GAN + fetches["L_GAN"] =model.L_GAN + if model.__class__.__name__ == "VanillaVAEVideoPredictionModel": + fetches["latent_loss"] = model.latent_loss + fetches["recon_loss"] = model.recon_loss results = sess.run(fetches) train_losses.append(results["total_loss"]) #Fetch losses for validation data @@ -333,6 +336,8 @@ def main(): print ("Total_loss:{}".format(results["total_loss"])) elif model.__class__.__name__ == "SAVPVideoPredictionModel": print("Total_loss/g_losses:{}; d_losses:{}; g_loss:{}; d_loss: {}".format(results["g_losses"],results["d_losses"],results["g_loss"],results["d_loss"])) + elif model.__class__.__name__ == "VanillaVAEVideoPredictionModel": + print("Total_loss:{}; latent_losses:{}; reconst_loss:{}".format(results["total_loss"],results["latent_loss"],results["recon_loss"])) else: print ("The model name does not exist") diff --git a/video_prediction_savp/video_prediction/models/vanilla_vae_model.py b/video_prediction_savp/video_prediction/models/vanilla_vae_model.py index 3e44b5ed..a0041107 100644 --- a/video_prediction_savp/video_prediction/models/vanilla_vae_model.py +++ b/video_prediction_savp/video_prediction/models/vanilla_vae_model.py @@ -21,6 +21,7 @@ class VanillaVAEVideoPredictionModel(BaseVideoPredictionModel): super(VanillaVAEVideoPredictionModel, self).__init__(mode, hparams_dict, hparams, **kwargs) self.mode = mode self.learning_rate = self.hparams.lr + self.weight_recon = self.hparams.weight_recon self.nz = self.hparams.nz self.aggregate_nccl=aggregate_nccl self.gen_images_enc = None @@ -29,6 +30,7 @@ class VanillaVAEVideoPredictionModel(BaseVideoPredictionModel): self.recon_loss = None self.latent_loss = None self.total_loss = None + def get_default_hparams_dict(self): """ @@ -67,11 +69,12 @@ class VanillaVAEVideoPredictionModel(BaseVideoPredictionModel): nz=10, context_frames=-1, sequence_length=-1, - clip_length=10, #Bing: TODO What is the clip_length, original is 10, + weight_recon = 0.4 + ) return dict(itertools.chain(default_hparams.items(), hparams.items())) - def build_graph(self,x) + def build_graph(self,x): self.x = x["images"] #self.global_step = tf.Variable(0, name = 'global_step', trainable = False) self.global_step = tf.train.get_or_create_global_step() @@ -82,7 +85,7 @@ class VanillaVAEVideoPredictionModel(BaseVideoPredictionModel): 1 + self.z_log_sigma_sq - tf.square(self.z_mu) - tf.exp(self.z_log_sigma_sq), axis=1) self.latent_loss = tf.reduce_mean(latent_loss) - self.total_loss = self.recon_loss + self.latent_loss + self.total_loss = self.weight_recon * self.recon_loss + self.latent_loss self.train_op = tf.train.AdamOptimizer( learning_rate = self.learning_rate).minimize(self.total_loss, global_step=self.global_step) # Build a saver @@ -109,7 +112,6 @@ class VanillaVAEVideoPredictionModel(BaseVideoPredictionModel): @staticmethod def vae_arc3(x,l_name=0,nz=16): seq_name = "sq_" + str(l_name) + "_" - conv1 = ld.conv_layer(x, 3, 2, 8, seq_name + "encode_1") conv2 = ld.conv_layer(conv1, 3, 1, 8, seq_name + "encode_2") conv3 = ld.conv_layer(conv2, 3, 2, 8, seq_name + "encode_3") @@ -121,12 +123,9 @@ class VanillaVAEVideoPredictionModel(BaseVideoPredictionModel): z = z_mu + tf.sqrt(tf.exp(z_log_sigma_sq)) * eps z2 = ld.fc_layer(z, hiddens = conv3_shape[1] * conv3_shape[2] * conv3_shape[3], idx = seq_name + "decode_fc1") z3 = tf.reshape(z2, [-1, conv3_shape[1], conv3_shape[2], conv3_shape[3]]) - conv5 = ld.transpose_conv_layer(z3, 3, 2, 8, - seq_name + "decode_5") - conv6 = ld.transpose_conv_layer(conv5, 3, 1, 8, - seq_name + "decode_6") + conv5 = ld.transpose_conv_layer(z3, 3, 2, 8, seq_name + "decode_5") + conv6 = ld.transpose_conv_layer(conv5, 3, 1, 8,seq_name + "decode_6") x_hat = ld.transpose_conv_layer(conv6, 3, 2, 3, seq_name + "decode_8") - return x_hat, z_mu, z_log_sigma_sq, z def vae_arc_all(self): -- GitLab