Skip to content
Snippets Groups Projects
Commit dd78d680 authored by gong1's avatar gong1
Browse files

update vae model

parent b4fb67da
Branches tgf19ts
No related tags found
No related merge requests found
......@@ -287,7 +287,7 @@ def main():
#fetches["latent_loss"] = model.latent_loss
fetches["summary"] = model.summary_op
if model.__class__.__name__ == "McNetVideoPredictionModel" or model.__class__.__name__ == "VanillaConvLstmVideoPredictionModel":
if model.__class__.__name__ == "McNetVideoPredictionModel" or model.__class__.__name__ == "VanillaConvLstmVideoPredictionModel" or model.__class__.__name__ == "VanillaVAEVideoPredictionModel":
fetches["global_step"] = model.global_step
fetches["total_loss"] = model.total_loss
#fetch the specific loss function only for mcnet
......@@ -295,6 +295,9 @@ def main():
fetches["L_p"] = model.L_p
fetches["L_gdl"] = model.L_gdl
fetches["L_GAN"] =model.L_GAN
if model.__class__.__name__ == "VanillaVAEVideoPredictionModel":
fetches["latent_loss"] = model.latent_loss
fetches["recon_loss"] = model.recon_loss
results = sess.run(fetches)
train_losses.append(results["total_loss"])
#Fetch losses for validation data
......@@ -333,6 +336,8 @@ def main():
print ("Total_loss:{}".format(results["total_loss"]))
elif model.__class__.__name__ == "SAVPVideoPredictionModel":
print("Total_loss/g_losses:{}; d_losses:{}; g_loss:{}; d_loss: {}".format(results["g_losses"],results["d_losses"],results["g_loss"],results["d_loss"]))
elif model.__class__.__name__ == "VanillaVAEVideoPredictionModel":
print("Total_loss:{}; latent_losses:{}; reconst_loss:{}".format(results["total_loss"],results["latent_loss"],results["recon_loss"]))
else:
print ("The model name does not exist")
......
......@@ -21,6 +21,7 @@ class VanillaVAEVideoPredictionModel(BaseVideoPredictionModel):
super(VanillaVAEVideoPredictionModel, self).__init__(mode, hparams_dict, hparams, **kwargs)
self.mode = mode
self.learning_rate = self.hparams.lr
self.weight_recon = self.hparams.weight_recon
self.nz = self.hparams.nz
self.aggregate_nccl=aggregate_nccl
self.gen_images_enc = None
......@@ -30,6 +31,7 @@ class VanillaVAEVideoPredictionModel(BaseVideoPredictionModel):
self.latent_loss = None
self.total_loss = None
def get_default_hparams_dict(self):
"""
The keys of this dict define valid hyperparameters for instances of
......@@ -67,11 +69,12 @@ class VanillaVAEVideoPredictionModel(BaseVideoPredictionModel):
nz=10,
context_frames=-1,
sequence_length=-1,
clip_length=10, # TODO(Bing): clarify what clip_length means; the original value is 10.
weight_recon = 0.4
)
return dict(itertools.chain(default_hparams.items(), hparams.items()))
def build_graph(self,x)
def build_graph(self,x):
self.x = x["images"]
#self.global_step = tf.Variable(0, name = 'global_step', trainable = False)
self.global_step = tf.train.get_or_create_global_step()
......@@ -82,7 +85,7 @@ class VanillaVAEVideoPredictionModel(BaseVideoPredictionModel):
1 + self.z_log_sigma_sq - tf.square(self.z_mu) -
tf.exp(self.z_log_sigma_sq), axis=1)
self.latent_loss = tf.reduce_mean(latent_loss)
self.total_loss = self.recon_loss + self.latent_loss
self.total_loss = self.weight_recon * self.recon_loss + self.latent_loss
self.train_op = tf.train.AdamOptimizer(
learning_rate = self.learning_rate).minimize(self.total_loss, global_step=self.global_step)
# Build a saver
......@@ -109,7 +112,6 @@ class VanillaVAEVideoPredictionModel(BaseVideoPredictionModel):
@staticmethod
def vae_arc3(x,l_name=0,nz=16):
seq_name = "sq_" + str(l_name) + "_"
conv1 = ld.conv_layer(x, 3, 2, 8, seq_name + "encode_1")
conv2 = ld.conv_layer(conv1, 3, 1, 8, seq_name + "encode_2")
conv3 = ld.conv_layer(conv2, 3, 2, 8, seq_name + "encode_3")
......@@ -121,12 +123,9 @@ class VanillaVAEVideoPredictionModel(BaseVideoPredictionModel):
z = z_mu + tf.sqrt(tf.exp(z_log_sigma_sq)) * eps
z2 = ld.fc_layer(z, hiddens = conv3_shape[1] * conv3_shape[2] * conv3_shape[3], idx = seq_name + "decode_fc1")
z3 = tf.reshape(z2, [-1, conv3_shape[1], conv3_shape[2], conv3_shape[3]])
conv5 = ld.transpose_conv_layer(z3, 3, 2, 8,
seq_name + "decode_5")
conv6 = ld.transpose_conv_layer(conv5, 3, 1, 8,
seq_name + "decode_6")
conv5 = ld.transpose_conv_layer(z3, 3, 2, 8, seq_name + "decode_5")
conv6 = ld.transpose_conv_layer(conv5, 3, 1, 8,seq_name + "decode_6")
x_hat = ld.transpose_conv_layer(conv6, 3, 2, 3, seq_name + "decode_8")
return x_hat, z_mu, z_log_sigma_sq, z
def vae_arc_all(self):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment