diff --git a/scripts/train_dummy.py b/scripts/train_dummy.py
index 6ebdb70bf24ecc53fd9611a7af948842600cd0db..e3ccc7c90373323129e00414d44dad23100fd9e8 100644
--- a/scripts/train_dummy.py
+++ b/scripts/train_dummy.py
@@ -127,8 +127,8 @@ def main():
                                    aggregate_nccl=args.aggregate_nccl)
 
     batch_size = model.hparams.batch_size
-    train_tf_dataset = train_dataset.make_dataset_v2(batch_size)#Bing: adopt the meteo data prepartion here
-    train_iterator = train_tf_dataset.make_one_shot_iterator()#Bing:for era5, the problem happen in sess.run(feches) should come from here
+    train_tf_dataset = train_dataset.make_dataset_v2(batch_size)
+    train_iterator = train_tf_dataset.make_one_shot_iterator()
     # The `Iterator.string_handle()` method returns a tensor that can be evaluated
     # and used to feed the `handle` placeholder.
     train_handle = train_iterator.string_handle()
@@ -139,7 +139,7 @@ def main():
     # train_handle, train_tf_dataset.output_types, train_tf_dataset.output_shapes)
     inputs = train_iterator.get_next()
     val = val_iterator.get_next()
-    # inputs comes from the training dataset by default, unless train_handle is remapped to the val_handles
+
     model.build_graph(inputs)
 
     if not os.path.exists(args.output_dir):
@@ -163,8 +163,8 @@ def main():
     gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=args.gpu_mem_frac, allow_growth=True)
     config = tf.ConfigProto(gpu_options=gpu_options, allow_soft_placement=True)
-    #global_step = tf.train.get_or_create_global_step()
-    #global_step = tf.Variable(0, name = 'global_step', trainable = False)
+
+
 
     max_steps = model.hparams.max_steps
     print ("max_steps",max_steps)
     with tf.Session(config=config) as sess:
@@ -175,15 +175,15 @@ def main():
         #threads = tf.train.start_queue_runners(sess = sess, coord = coord)
         print("Init done: {sess.run(tf.local_variables_initializer())}%")
         model.restore(sess, args.checkpoint)
-        print("Restore processed finished")
+
         #sess.run(model.post_init_ops)
-        print("Model run started")
+
         #val_handle_eval = sess.run(val_handle)
         #print ("val_handle_val",val_handle_eval)
         #print("val handle done")
         sess.graph.finalize()
         start_step = sess.run(model.global_step)
-        print("global step done")
+
 
         # start at one step earlier to log everything without doing any training
         # step is relative to the start_step
@@ -191,7 +191,7 @@ def main():
             global_step = sess.run(model.global_step)
             print ("global_step:", global_step)
             val_handle_eval = sess.run(val_handle)
-            print ("val_handle_val",val_handle_eval)
+
             if step == 1:
                 # skip step -1 and 0 for timing purposes (for warmstarting)
                 start_time = time.time()
@@ -211,9 +211,9 @@ def main():
             run_start_time = time.time()
             #Run training results
             X = inputs["images"].eval(session=sess)
-            #results = sess.run(fetches,feed_dict={model.x:X}) #fetch the elements in dictinoary fetch
+
             results = sess.run(fetches)
-            print ("results global step:",results["global_step"])
+
             run_elapsed_time = time.time() - run_start_time
             if run_elapsed_time > 1.5 and step > 0 and set(fetches.keys()) == {"global_step", "train_op"}:
                 print('running train_op took too long (%0.1fs)' % run_elapsed_time)
@@ -229,8 +229,8 @@ def main():
                 summary_writer.add_summary(results["summary"])
                 summary_writer.add_summary(val_results["summary"])
-                #print("results_global_step", results["global_step"])
-                #print("Val_results_global_step", val_results["global_step"])
+
+
 
                 val_datasets = [val_dataset]
                 val_models = [model]
 
@@ -262,20 +262,20 @@ def main():
                 print("image/sec %0.1f remaining %dm (%0.1fh) (%0.1fd)" % (images_per_sec, remaining_time / 60, remaining_time / 60 / 60, remaining_time / 60 / 60 / 24))
-                # if results['d_losses']:
-                # print("d_loss", results["d_loss"])
-                # for name, loss in results['d_losses'].items():
-                # print(" ", name, loss)
-                # if results['g_losses']:
-                # print("g_loss", results["g_loss"])
-                # for name, loss in results['g_losses'].items():
-                # print(" ", name, loss)
-                #for name, loss in results['total_loss'].items():
+
+
+
+
+
+
+
+
+
                 print(" Results_total_loss",results["total_loss"])
 
         print("saving model to", args.output_dir)
         saver.save(sess, os.path.join(args.output_dir, "model"), global_step=step)##Bing: cheat here a little bit because of the global step issue
         print("done")
-    #global_step = global_step + 1
+
 
 
 if __name__ == '__main__':
     main()
diff --git a/video_prediction/models/vanilla_convLSTM_model.py b/video_prediction/models/vanilla_convLSTM_model.py
index 8cd2ad3f2b99e9a88c9471db2c0dc6f4ccb89913..225d4a5493158ab77dfb182f7f1a45fa5156286e 100644
--- a/video_prediction/models/vanilla_convLSTM_model.py
+++ b/video_prediction/models/vanilla_convLSTM_model.py
@@ -72,13 +72,13 @@ class VanillaConvLstmVideoPredictionModel(BaseVideoPredictionModel):
 
     def build_graph(self, x):
         self.x = x["images"]
-        #self.global_step = tf.train.get_or_create_global_step()
+
         self.global_step = tf.Variable(0, name = 'global_step', trainable = False)
         original_global_variables = tf.global_variables()
         # ARCHITECTURE
         self.x_hat_context_frames, self.x_hat_predict_frames = self.convLSTM_network()
         self.x_hat = tf.concat([self.x_hat_context_frames, self.x_hat_predict_frames], 1)
-        print("x_hat,shape", self.x_hat)
+
         self.context_frames_loss = tf.reduce_mean(
             tf.square(self.x[:, :self.context_frames, :, :, 0] - self.x_hat_context_frames[:, :, :, :, 0]))
 
@@ -102,13 +102,13 @@ class VanillaConvLstmVideoPredictionModel(BaseVideoPredictionModel):
 
     @staticmethod
     def convLSTM_cell(inputs, hidden, nz=16):
-        print("Inputs shape", inputs.shape)
+
         conv1 = ld.conv_layer(inputs, 3, 2, 8, "encode_1", activate = "leaky_relu")
-        print("Encode_1_shape", conv1.shape)
+
         conv2 = ld.conv_layer(conv1, 3, 1, 8, "encode_2", activate = "leaky_relu")
-        print("Encode 2_shape,", conv2.shape)
+
         conv3 = ld.conv_layer(conv2, 3, 2, 8, "encode_3", activate = "leaky_relu")
-        print("Encode 3_shape, ", conv3.shape)
+
         y_0 = conv3
         # conv lstm cell
         cell_shape = y_0.get_shape().as_list()
@@ -116,23 +116,23 @@ class VanillaConvLstmVideoPredictionModel(BaseVideoPredictionModel):
         cell = BasicConvLSTMCell(shape = [cell_shape[1], cell_shape[2]], filter_size = [3, 3], num_features = 8)
         if hidden is None:
             hidden = cell.zero_state(y_0, tf.float32)
-        print("hidden zero layer", hidden.shape)
+
         output, hidden = cell(y_0, hidden)
-        print("output for cell:", output)
+
 
         output_shape = output.get_shape().as_list()
-        print("output_shape,", output_shape)
+
 
         z3 = tf.reshape(output, [-1, output_shape[1], output_shape[2], output_shape[3]])
 
         conv5 = ld.transpose_conv_layer(z3, 3, 2, 8, "decode_5", activate = "leaky_relu")
-        print("conv5 shape", conv5)
+
         conv6 = ld.transpose_conv_layer(conv5, 3, 1, 8, "decode_6", activate = "leaky_relu")
-        print("conv6 shape", conv6)
+
 
         x_hat = ld.transpose_conv_layer(conv6, 3, 2, 3, "decode_7", activate = "sigmoid") # set activation to linear
-        print("x hat shape", x_hat)
+
 
         return x_hat, hidden
 
     def convLSTM_network(self):
diff --git a/video_prediction/models/vanilla_vae_model.py b/video_prediction/models/vanilla_vae_model.py
index 74280896dca61007c1b361ec4caff9ad5f718d26..eec5598305044226280080d630313487c7d847a4 100644
--- a/video_prediction/models/vanilla_vae_model.py
+++ b/video_prediction/models/vanilla_vae_model.py
@@ -73,38 +73,37 @@ class VanillaVAEVideoPredictionModel(BaseVideoPredictionModel):
 
 
     def build_graph(self,x):
-        #global_step = tf.train.get_or_create_global_step()
-        #original_global_variables = tf.global_variables()
-        # self.x = x["images"]
-        #print ("self_x:",self.x)
-        #tf.reset_default_graph()
-        #self.x = tf.placeholder(tf.float32, [None,20,64,64,3])
+
+
+
+
+
         tf.set_random_seed(12345)
         self.x = x["images"]
-
-        #self.global_step = tf.train.get_or_create_global_step()
+
+
         self.global_step = tf.Variable(0, name = 'global_step', trainable = False)
         original_global_variables = tf.global_variables()
         self.increment_global_step = tf.assign_add(self.global_step, 1, name = 'increment_global_step')
         self.x_hat, self.z_log_sigma_sq, self.z_mu = self.vae_arc_all()
 
-        # Loss
-        # Reconstruction loss
-        # Minimize the cross-entropy loss
-        # epsilon = 1e-10
-        # recon_loss = -tf.reduce_sum(
-        # self.x[:,1:,:,:,:] * tf.log(epsilon+self.x_hat[:,:-1,:,:,:]) +
-        # (1-self.x[:,1:,:,:,:]) * tf.log(epsilon+1-self.x_hat[:,:-1,:,:,:]),
-        # axis=1
-        # )
-
-        # self.recon_loss = tf.reduce_mean(recon_loss)
+
+
+
+
+
+
+
+
+
+
+
         self.recon_loss = tf.reduce_mean(tf.square(self.x[:, 1:, :, :, 0] - self.x_hat[:, :-1, :, :, 0]))
 
-        # Latent loss
-        # KL divergence: measure the difference between two distributions
-        # Here we measure the divergence between
-        # the latent distribution and N(0, 1)
+
+
+
+
         latent_loss = -0.5 * tf.reduce_sum(
             1 + self.z_log_sigma_sq - tf.square(self.z_mu) - tf.exp(self.z_log_sigma_sq), axis = 1)
 
@@ -113,7 +112,7 @@ class VanillaVAEVideoPredictionModel(BaseVideoPredictionModel):
         self.train_op = tf.train.AdamOptimizer(
             learning_rate = self.learning_rate).minimize(self.total_loss, global_step = self.global_step)
         # Build a saver
-        #self.saver = tf.train.Saver(tf.global_variables())
+
         self.losses = {
             'recon_loss': self.recon_loss,
             'latent_loss': self.latent_loss,
@@ -125,14 +124,14 @@ class VanillaVAEVideoPredictionModel(BaseVideoPredictionModel):
         self.loss_summary = tf.summary.scalar("latent_loss", self.latent_loss)
         self.loss_summary = tf.summary.scalar("total_loss", self.latent_loss)
         self.summary_op = tf.summary.merge_all()
-        # H(x, x_hat) = -\Sigma x*log(x_hat) + (1-x)*log(1-x_hat)
-        # self.ckpt = tf.train.Checkpoint(model=self.vae_arc2())
-        # self.manager = tf.train.CheckpointManager(self.ckpt,self.checkpoint_dir,max_to_keep=3)
+
+
+
         self.outputs = {}
         self.outputs["gen_images"] = self.x_hat
         global_variables = [var for var in tf.global_variables() if var not in original_global_variables]
         self.saveable_variables = [self.global_step] + global_variables
-        #train_op = tf.assign_add(global_step, 1)
+
 
         return
 
@@ -141,18 +140,18 @@ class VanillaVAEVideoPredictionModel(BaseVideoPredictionModel):
 
 
         seq_name = "sq_" + str(l_name) + "_"
         conv1 = ld.conv_layer(x, 3, 2, 8, seq_name + "encode_1")
-        print("Encode_1_shape", conv1.shape) # (?,2,2,8)
-        # conv2
-        conv2 = ld.conv_layer(conv1, 3, 1, 8, seq_name + "encode_2") # (?,2,2,8)
-        print("Encode 2_shape,", conv2.shape)
-        # conv3
-        conv3 = ld.conv_layer(conv2, 3, 2, 8, seq_name + "encode_3") # (?,1,1,8)
-        print("Encode 3_shape, ", conv3.shape)
-        # flatten
+
+
+        conv2 = ld.conv_layer(conv1, 3, 1, 8, seq_name + "encode_2")
+
+
+        conv3 = ld.conv_layer(conv2, 3, 2, 8, seq_name + "encode_3")
+
+
         conv4 = tf.layers.Flatten()(conv3)
-        print("Encode 4_shape, ", conv4.shape)
+
         conv3_shape = conv3.get_shape().as_list()
-        print("conv4_shape",conv3_shape)
+
         z_mu = ld.fc_layer(conv4, hiddens = nz, idx = seq_name + "enc_fc4_m")
         z_log_sigma_sq = ld.fc_layer(conv4, hiddens = nz, idx = seq_name + "enc_fc4_m"'enc_fc4_sigma')
@@ -163,16 +162,16 @@ class VanillaVAEVideoPredictionModel(BaseVideoPredictionModel):
         z3 = tf.reshape(z2, [-1, conv3_shape[1], conv3_shape[2], conv3_shape[3]])
 
-        # conv5
+
         conv5 = ld.transpose_conv_layer(z3, 3, 2, 8,
-                                        seq_name + "decode_5") # (16,1,1,8)inputs, kernel_size, stride, num_features
-        print("Decode 5 shape", conv5.shape)
+                                        seq_name + "decode_5")
+
         conv6 = ld.transpose_conv_layer(conv5, 3, 1, 8,
-                                        seq_name + "decode_6") # (16,1,1,8)inputs, kernel_size, stride, num_features
+                                        seq_name + "decode_6")
 
-        # x_1
-        x_hat = ld.transpose_conv_layer(conv6, 3, 2, 3, seq_name + "decode_8") # set activation to linear
-        print("X_hat", x_hat.shape)
+
+        x_hat = ld.transpose_conv_layer(conv6, 3, 2, 3, seq_name + "decode_8")
+
 
         return x_hat, z_mu, z_log_sigma_sq, z
 
     def vae_arc_all(self):
@@ -187,6 +186,6 @@ class VanillaVAEVideoPredictionModel(BaseVideoPredictionModel):
         x_hat = tf.stack(X, axis = 1)
         z_log_sigma_sq_all = tf.stack(z_log_sigma_sq_all, axis = 1)
         z_mu_all = tf.stack(z_mu_all, axis = 1)
-        print("X_hat", x_hat.shape)
-        print("zlog_sigma_sq_all", z_log_sigma_sq_all.shape)
+
+
         return x_hat, z_log_sigma_sq_all, z_mu_all