Skip to content
Snippets Groups Projects
Commit dd78d680 authored by gong1's avatar gong1
Browse files

update vae model

parent b4fb67da
No related branches found
No related tags found
No related merge requests found
...@@ -287,7 +287,7 @@ def main(): ...@@ -287,7 +287,7 @@ def main():
#fetches["latent_loss"] = model.latent_loss #fetches["latent_loss"] = model.latent_loss
fetches["summary"] = model.summary_op fetches["summary"] = model.summary_op
if model.__class__.__name__ == "McNetVideoPredictionModel" or model.__class__.__name__ == "VanillaConvLstmVideoPredictionModel": if model.__class__.__name__ == "McNetVideoPredictionModel" or model.__class__.__name__ == "VanillaConvLstmVideoPredictionModel" or model.__class__.__name__ == "VanillaVAEVideoPredictionModel":
fetches["global_step"] = model.global_step fetches["global_step"] = model.global_step
fetches["total_loss"] = model.total_loss fetches["total_loss"] = model.total_loss
#fetch the specific loss function only for mcnet #fetch the specific loss function only for mcnet
...@@ -295,6 +295,9 @@ def main(): ...@@ -295,6 +295,9 @@ def main():
fetches["L_p"] = model.L_p fetches["L_p"] = model.L_p
fetches["L_gdl"] = model.L_gdl fetches["L_gdl"] = model.L_gdl
fetches["L_GAN"] =model.L_GAN fetches["L_GAN"] =model.L_GAN
if model.__class__.__name__ == "VanillaVAEVideoPredictionModel":
fetches["latent_loss"] = model.latent_loss
fetches["recon_loss"] = model.recon_loss
results = sess.run(fetches) results = sess.run(fetches)
train_losses.append(results["total_loss"]) train_losses.append(results["total_loss"])
#Fetch losses for validation data #Fetch losses for validation data
...@@ -333,6 +336,8 @@ def main(): ...@@ -333,6 +336,8 @@ def main():
print ("Total_loss:{}".format(results["total_loss"])) print ("Total_loss:{}".format(results["total_loss"]))
elif model.__class__.__name__ == "SAVPVideoPredictionModel": elif model.__class__.__name__ == "SAVPVideoPredictionModel":
print("Total_loss/g_losses:{}; d_losses:{}; g_loss:{}; d_loss: {}".format(results["g_losses"],results["d_losses"],results["g_loss"],results["d_loss"])) print("Total_loss/g_losses:{}; d_losses:{}; g_loss:{}; d_loss: {}".format(results["g_losses"],results["d_losses"],results["g_loss"],results["d_loss"]))
elif model.__class__.__name__ == "VanillaVAEVideoPredictionModel":
print("Total_loss:{}; latent_losses:{}; reconst_loss:{}".format(results["total_loss"],results["latent_loss"],results["recon_loss"]))
else: else:
print ("The model name does not exist") print ("The model name does not exist")
......
...@@ -21,6 +21,7 @@ class VanillaVAEVideoPredictionModel(BaseVideoPredictionModel): ...@@ -21,6 +21,7 @@ class VanillaVAEVideoPredictionModel(BaseVideoPredictionModel):
super(VanillaVAEVideoPredictionModel, self).__init__(mode, hparams_dict, hparams, **kwargs) super(VanillaVAEVideoPredictionModel, self).__init__(mode, hparams_dict, hparams, **kwargs)
self.mode = mode self.mode = mode
self.learning_rate = self.hparams.lr self.learning_rate = self.hparams.lr
self.weight_recon = self.hparams.weight_recon
self.nz = self.hparams.nz self.nz = self.hparams.nz
self.aggregate_nccl=aggregate_nccl self.aggregate_nccl=aggregate_nccl
self.gen_images_enc = None self.gen_images_enc = None
...@@ -30,6 +31,7 @@ class VanillaVAEVideoPredictionModel(BaseVideoPredictionModel): ...@@ -30,6 +31,7 @@ class VanillaVAEVideoPredictionModel(BaseVideoPredictionModel):
self.latent_loss = None self.latent_loss = None
self.total_loss = None self.total_loss = None
def get_default_hparams_dict(self): def get_default_hparams_dict(self):
""" """
The keys of this dict define valid hyperparameters for instances of The keys of this dict define valid hyperparameters for instances of
...@@ -67,11 +69,12 @@ class VanillaVAEVideoPredictionModel(BaseVideoPredictionModel): ...@@ -67,11 +69,12 @@ class VanillaVAEVideoPredictionModel(BaseVideoPredictionModel):
nz=10, nz=10,
context_frames=-1, context_frames=-1,
sequence_length=-1, sequence_length=-1,
clip_length=10, #Bing: TODO What is the clip_length, original is 10, weight_recon = 0.4
) )
return dict(itertools.chain(default_hparams.items(), hparams.items())) return dict(itertools.chain(default_hparams.items(), hparams.items()))
def build_graph(self,x) def build_graph(self,x):
self.x = x["images"] self.x = x["images"]
#self.global_step = tf.Variable(0, name = 'global_step', trainable = False) #self.global_step = tf.Variable(0, name = 'global_step', trainable = False)
self.global_step = tf.train.get_or_create_global_step() self.global_step = tf.train.get_or_create_global_step()
...@@ -82,7 +85,7 @@ class VanillaVAEVideoPredictionModel(BaseVideoPredictionModel): ...@@ -82,7 +85,7 @@ class VanillaVAEVideoPredictionModel(BaseVideoPredictionModel):
1 + self.z_log_sigma_sq - tf.square(self.z_mu) - 1 + self.z_log_sigma_sq - tf.square(self.z_mu) -
tf.exp(self.z_log_sigma_sq), axis=1) tf.exp(self.z_log_sigma_sq), axis=1)
self.latent_loss = tf.reduce_mean(latent_loss) self.latent_loss = tf.reduce_mean(latent_loss)
self.total_loss = self.recon_loss + self.latent_loss self.total_loss = self.weight_recon * self.recon_loss + self.latent_loss
self.train_op = tf.train.AdamOptimizer( self.train_op = tf.train.AdamOptimizer(
learning_rate = self.learning_rate).minimize(self.total_loss, global_step=self.global_step) learning_rate = self.learning_rate).minimize(self.total_loss, global_step=self.global_step)
# Build a saver # Build a saver
...@@ -109,7 +112,6 @@ class VanillaVAEVideoPredictionModel(BaseVideoPredictionModel): ...@@ -109,7 +112,6 @@ class VanillaVAEVideoPredictionModel(BaseVideoPredictionModel):
@staticmethod @staticmethod
def vae_arc3(x,l_name=0,nz=16): def vae_arc3(x,l_name=0,nz=16):
seq_name = "sq_" + str(l_name) + "_" seq_name = "sq_" + str(l_name) + "_"
conv1 = ld.conv_layer(x, 3, 2, 8, seq_name + "encode_1") conv1 = ld.conv_layer(x, 3, 2, 8, seq_name + "encode_1")
conv2 = ld.conv_layer(conv1, 3, 1, 8, seq_name + "encode_2") conv2 = ld.conv_layer(conv1, 3, 1, 8, seq_name + "encode_2")
conv3 = ld.conv_layer(conv2, 3, 2, 8, seq_name + "encode_3") conv3 = ld.conv_layer(conv2, 3, 2, 8, seq_name + "encode_3")
...@@ -121,12 +123,9 @@ class VanillaVAEVideoPredictionModel(BaseVideoPredictionModel): ...@@ -121,12 +123,9 @@ class VanillaVAEVideoPredictionModel(BaseVideoPredictionModel):
z = z_mu + tf.sqrt(tf.exp(z_log_sigma_sq)) * eps z = z_mu + tf.sqrt(tf.exp(z_log_sigma_sq)) * eps
z2 = ld.fc_layer(z, hiddens = conv3_shape[1] * conv3_shape[2] * conv3_shape[3], idx = seq_name + "decode_fc1") z2 = ld.fc_layer(z, hiddens = conv3_shape[1] * conv3_shape[2] * conv3_shape[3], idx = seq_name + "decode_fc1")
z3 = tf.reshape(z2, [-1, conv3_shape[1], conv3_shape[2], conv3_shape[3]]) z3 = tf.reshape(z2, [-1, conv3_shape[1], conv3_shape[2], conv3_shape[3]])
conv5 = ld.transpose_conv_layer(z3, 3, 2, 8, conv5 = ld.transpose_conv_layer(z3, 3, 2, 8, seq_name + "decode_5")
seq_name + "decode_5") conv6 = ld.transpose_conv_layer(conv5, 3, 1, 8,seq_name + "decode_6")
conv6 = ld.transpose_conv_layer(conv5, 3, 1, 8,
seq_name + "decode_6")
x_hat = ld.transpose_conv_layer(conv6, 3, 2, 3, seq_name + "decode_8") x_hat = ld.transpose_conv_layer(conv6, 3, 2, 3, seq_name + "decode_8")
return x_hat, z_mu, z_log_sigma_sq, z return x_hat, z_mu, z_log_sigma_sq, z
def vae_arc_all(self): def vae_arc_all(self):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment