From dd78d6803917d044d138f9247bb9341f81b07a44 Mon Sep 17 00:00:00 2001
From: gong1 <b.gong@fz-juelich.de>
Date: Mon, 10 Aug 2020 14:56:49 +0200
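Subject: [PATCH] update vae model

Add a weight_recon hyperparameter (default 0.4) that scales the
reconstruction term in the VanillaVAE total loss, and drop the
clip_length default from the model hparams. Fix the missing colon on
the build_graph definition. In train_dummy.py, fetch and print
latent_loss and recon_loss alongside total_loss when the model is
VanillaVAEVideoPredictionModel.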

---
 video_prediction_savp/scripts/train_dummy.py    |  9 +++++++--
 .../models/vanilla_vae_model.py                 | 17 ++++++++---------
 2 files changed, 15 insertions(+), 11 deletions(-)

diff --git a/video_prediction_savp/scripts/train_dummy.py b/video_prediction_savp/scripts/train_dummy.py
index 468f929c..d4db02c3 100644
--- a/video_prediction_savp/scripts/train_dummy.py
+++ b/video_prediction_savp/scripts/train_dummy.py
@@ -287,14 +287,17 @@ def main():
             #fetches["latent_loss"] = model.latent_loss
             fetches["summary"] = model.summary_op 
             
-            if model.__class__.__name__ == "McNetVideoPredictionModel" or model.__class__.__name__ == "VanillaConvLstmVideoPredictionModel":
+            if model.__class__.__name__ == "McNetVideoPredictionModel" or model.__class__.__name__ == "VanillaConvLstmVideoPredictionModel" or model.__class__.__name__ == "VanillaVAEVideoPredictionModel":
                 fetches["global_step"] = model.global_step
                 fetches["total_loss"] = model.total_loss
                 #fetch the specific loss function only for mcnet
                 if model.__class__.__name__ == "McNetVideoPredictionModel":
                     fetches["L_p"] = model.L_p
                     fetches["L_gdl"] = model.L_gdl
-                    fetches["L_GAN"]  =model.L_GAN                    
+                    fetches["L_GAN"]  =model.L_GAN
+                if model.__class__.__name__ == "VanillaVAEVideoPredictionModel":
+                    fetches["latent_loss"] = model.latent_loss
+                    fetches["recon_loss"] = model.recon_loss
                 results = sess.run(fetches)
                 train_losses.append(results["total_loss"])
                 #Fetch losses for validation data
@@ -333,6 +336,8 @@ def main():
                 print ("Total_loss:{}".format(results["total_loss"]))
             elif model.__class__.__name__ == "SAVPVideoPredictionModel":
                 print("Total_loss/g_losses:{}; d_losses:{}; g_loss:{}; d_loss: {}".format(results["g_losses"],results["d_losses"],results["g_loss"],results["d_loss"]))
+            elif model.__class__.__name__ == "VanillaVAEVideoPredictionModel":
+                print("Total_loss:{}; latent_losses:{}; reconst_loss:{}".format(results["total_loss"],results["latent_loss"],results["recon_loss"]))
             else:
                 print ("The model name does not exist")
 
diff --git a/video_prediction_savp/video_prediction/models/vanilla_vae_model.py b/video_prediction_savp/video_prediction/models/vanilla_vae_model.py
index 3e44b5ed..a0041107 100644
--- a/video_prediction_savp/video_prediction/models/vanilla_vae_model.py
+++ b/video_prediction_savp/video_prediction/models/vanilla_vae_model.py
@@ -21,6 +21,7 @@ class VanillaVAEVideoPredictionModel(BaseVideoPredictionModel):
         super(VanillaVAEVideoPredictionModel, self).__init__(mode, hparams_dict, hparams, **kwargs)
         self.mode = mode
         self.learning_rate = self.hparams.lr
+        self.weight_recon = self.hparams.weight_recon
         self.nz = self.hparams.nz
         self.aggregate_nccl=aggregate_nccl
         self.gen_images_enc = None
@@ -29,6 +30,7 @@ class VanillaVAEVideoPredictionModel(BaseVideoPredictionModel):
         self.recon_loss = None
         self.latent_loss = None
         self.total_loss = None
+        
 
     def get_default_hparams_dict(self):
         """
@@ -67,11 +69,12 @@ class VanillaVAEVideoPredictionModel(BaseVideoPredictionModel):
             nz=10,
             context_frames=-1,
             sequence_length=-1,
-            clip_length=10, #Bing: TODO What is the clip_length, original is 10,
+            weight_recon = 0.4        
+            
         )
         return dict(itertools.chain(default_hparams.items(), hparams.items()))
 
-    def build_graph(self,x)  
+    def build_graph(self,x):
         self.x = x["images"]
         #self.global_step = tf.Variable(0, name = 'global_step', trainable = False)
         self.global_step = tf.train.get_or_create_global_step()
@@ -82,7 +85,7 @@ class VanillaVAEVideoPredictionModel(BaseVideoPredictionModel):
             1 + self.z_log_sigma_sq - tf.square(self.z_mu) -
             tf.exp(self.z_log_sigma_sq), axis=1)
         self.latent_loss = tf.reduce_mean(latent_loss)
-        self.total_loss = self.recon_loss + self.latent_loss
+        self.total_loss = self.weight_recon * self.recon_loss + self.latent_loss
         self.train_op = tf.train.AdamOptimizer(
             learning_rate = self.learning_rate).minimize(self.total_loss, global_step=self.global_step)
         # Build a saver
@@ -109,7 +112,6 @@ class VanillaVAEVideoPredictionModel(BaseVideoPredictionModel):
     @staticmethod
     def vae_arc3(x,l_name=0,nz=16):
         seq_name = "sq_" + str(l_name) + "_"
-
         conv1 = ld.conv_layer(x, 3, 2, 8, seq_name + "encode_1")
         conv2 = ld.conv_layer(conv1, 3, 1, 8, seq_name + "encode_2")
         conv3 = ld.conv_layer(conv2, 3, 2, 8, seq_name + "encode_3")
@@ -121,12 +123,9 @@ class VanillaVAEVideoPredictionModel(BaseVideoPredictionModel):
         z = z_mu + tf.sqrt(tf.exp(z_log_sigma_sq)) * eps        
         z2 = ld.fc_layer(z, hiddens = conv3_shape[1] * conv3_shape[2] * conv3_shape[3], idx = seq_name + "decode_fc1") 
         z3 = tf.reshape(z2, [-1, conv3_shape[1], conv3_shape[2], conv3_shape[3]])
-        conv5 = ld.transpose_conv_layer(z3, 3, 2, 8,
-                                        seq_name + "decode_5")  
-        conv6  = ld.transpose_conv_layer(conv5, 3, 1, 8,
-                                        seq_name + "decode_6")
+        conv5 = ld.transpose_conv_layer(z3, 3, 2, 8, seq_name + "decode_5")  
+        conv6  = ld.transpose_conv_layer(conv5, 3, 1, 8,seq_name + "decode_6")
         x_hat = ld.transpose_conv_layer(conv6, 3, 2, 3, seq_name + "decode_8")
-
         return x_hat, z_mu, z_log_sigma_sq, z
 
     def vae_arc_all(self):
-- 
GitLab