diff --git a/test/run_pytest.sh b/test/run_pytest.sh index ed2b531562282d8fc6a34c07cf8e46cad1a83460..83220d34a51379e93add931ae6e03e9491b5bce4 100644 --- a/test/run_pytest.sh +++ b/test/run_pytest.sh @@ -2,7 +2,7 @@ # Name of virtual environment #VIRT_ENV_NAME="vp_new_structure" -VIRT_ENV_NAME="env_hdfml" +VIRT_ENV_NAME="juwels_env" if [ -z ${VIRTUAL_ENV} ]; then if [[ -f ../video_prediction_tools/${VIRT_ENV_NAME}/bin/activate ]]; then diff --git a/video_prediction_tools/hparams/era5/convLSTM_gan/model_hparams_template.json b/video_prediction_tools/hparams/era5/convLSTM_gan/model_hparams_template.json new file mode 100644 index 0000000000000000000000000000000000000000..bd0357a180631f1ba7f0d0b1732af1d18aa878c3 --- /dev/null +++ b/video_prediction_tools/hparams/era5/convLSTM_gan/model_hparams_template.json @@ -0,0 +1,15 @@ + +{ + "batch_size": 4, + "lr": 0.001, + "max_epochs":20, + "context_frames":12, + "sequence_length":24, + "loss_fun":"rmse", + "shuffle_on_val":false, + "recon_weight":0.6 + +} + + + diff --git a/video_prediction_tools/main_scripts/main_train_models.py b/video_prediction_tools/main_scripts/main_train_models.py index c02378a702f5d807210cbef890b507fc99a8c7c7..c6ac709d7d5d82487e60ab915bf7b41cd11ffabc 100644 --- a/video_prediction_tools/main_scripts/main_train_models.py +++ b/video_prediction_tools/main_scripts/main_train_models.py @@ -177,8 +177,8 @@ class TrainModel(object): self.inputs = self.iterator.get_next() #since era5 tfrecords include T_start, we need to remove it from the tfrecord when we train the model, # otherwise the model will raise error - if self.dataset == "era5" and self.model == "savp": - del self.inputs["T_start"] + #if self.dataset == "era5" and self.model == "savp": + # del self.inputs["T_start"] @@ -231,6 +231,7 @@ class TrainModel(object): self.num_examples = self.train_dataset.num_examples_per_epoch() self.steps_per_epoch = int(self.num_examples/batch_size) self.total_steps = self.steps_per_epoch * max_epochs + print("Batch size is {} ; max_epochs is {}; num_samples per epoch is {}; steps_per_epoch is {}, total steps is {}".format(batch_size,max_epochs, self.num_examples,self.steps_per_epoch,self.total_steps)) def restore(self,sess, checkpoints, restore_to_checkpoint_mapping=None): """ @@ -292,11 +293,15 @@ class TrainModel(object): self.create_fetches_for_train() # In addition to the loss, we fetch the optimizer self.results = sess.run(self.fetches) # ...and run it here! 
train_losses.append(self.results["total_loss"]) + print("t_start for training",self.results["inputs"]["T_start"]) + print("len of t_start per iteration",len(self.results["inputs"]["T_start"])) #Run and fetch losses for validation data val_handle_eval = sess.run(self.val_handle) self.create_fetches_for_val() self.val_results = sess.run(self.val_fetches,feed_dict={self.train_handle: val_handle_eval}) val_losses.append(self.val_results["total_loss"]) + print("t_start for validation",self.val_results["inputs"]["T_start"]) + print("len of t_start per iteration",len(self.val_results["inputs"]["T_start"])) self.write_to_summary() self.print_results(step,self.results) timeit_end = time.time() @@ -333,6 +338,8 @@ class TrainModel(object): if self.video_model.__class__.__name__ == "VanillaConvLstmVideoPredictionModel": self.fetches_for_train_convLSTM() if self.video_model.__class__.__name__ == "SAVPVideoPredictionModel": self.fetches_for_train_savp() if self.video_model.__class__.__name__ == "VanillaVAEVideoPredictionModel": self.fetches_for_train_vae() + if self.video_model.__class__.__name__ == "VanillaGANVideoPredictionModel":self.fetches_for_train_gan() + if self.video_model.__class__.__name__ == "ConvLstmGANVideoPredictionModel":self.fetches_for_train_convLSTM() return self.fetches def fetches_for_train_convLSTM(self): @@ -340,8 +347,7 @@ class TrainModel(object): Fetch variables in the graph for convLSTM model, this can be custermized based on models and the needs of users """ self.fetches["total_loss"] = self.video_model.total_loss - - + self.fetches["inputs"] = self.video_model.inputs def fetches_for_train_savp(self): @@ -353,7 +359,7 @@ class TrainModel(object): self.fetches["d_loss"] = self.video_model.d_loss self.fetches["g_loss"] = self.video_model.g_loss self.fetches["total_loss"] = self.video_model.g_loss - + self.fetches["inputs"] = self.video_model.inputs def fetches_for_train_mcnet(self): @@ -372,15 +378,19 @@ class TrainModel(object): self.fetches["recon_loss"] = self.video_model.recon_loss self.fetches["total_loss"] = self.video_model.total_loss + def fetches_for_train_gan(self): + self.fetches["total_loss"] = self.video_model.total_loss + def create_fetches_for_val(self): """ Fetch variables in the graph for validation dataset, this can be custermized based on models and the needs of users """ if self.video_model.__class__.__name__ == "SAVPVideoPredictionModel": self.val_fetches = {"total_loss": self.video_model.g_loss} + self.val_fetches["inputs"] = self.video_model.inputs else: self.val_fetches = {"total_loss": self.video_model.total_loss} - + self.val_fetches["inputs"] = self.video_model.inputs self.val_fetches["summary"] = self.video_model.summary_op def write_to_summary(self): diff --git a/video_prediction_tools/main_scripts/main_visualize_postprocess.py b/video_prediction_tools/main_scripts/main_visualize_postprocess.py index 68017b4597080b771b00860ab32cf693c0714d73..60dfef0032a7746d57734285a7b86328e0c74f50 100644 --- a/video_prediction_tools/main_scripts/main_visualize_postprocess.py +++ b/video_prediction_tools/main_scripts/main_visualize_postprocess.py @@ -422,6 +422,7 @@ class Postprocess(TrainModel): # feed and run the trained model; returned array has the shape [batchsize, seq_len, lat, lon, channel] feed_dict = {input_ph: input_results[name] for name, input_ph in self.inputs.items()} gen_images = self.sess.run(self.video_model.outputs['gen_images'], feed_dict=feed_dict) + # sanity check on length of forecast sequence assert gen_images.shape[1] == self.sequence_length 
- 1, \ "%{0}: Sequence length of prediction must be smaller by one than total sequence length.".format(method) diff --git a/video_prediction_tools/model_modules/model_architectures.py b/video_prediction_tools/model_modules/model_architectures.py index ca602a954a107c9217942e7f01e4eae4c68d58bb..5836ab9fce48692252a4dbc44415b4a4f9e2c2c3 100644 --- a/video_prediction_tools/model_modules/model_architectures.py +++ b/video_prediction_tools/model_modules/model_architectures.py @@ -14,6 +14,8 @@ def known_models(): 'vae': 'VanillaVAEVideoPredictionModel', 'convLSTM': 'VanillaConvLstmVideoPredictionModel', 'mcnet': 'McNetVideoPredictionModel', + 'gan': "VanillaGANVideoPredictionModel", + 'convLSTM_gan': "ConvLstmGANVideoPredictionModel", 'ours_vae_l1': 'SAVPVideoPredictionModel', 'ours_gan': 'SAVPVideoPredictionModel', } diff --git a/video_prediction_tools/model_modules/video_prediction/datasets/era5_dataset.py b/video_prediction_tools/model_modules/video_prediction/datasets/era5_dataset.py index d26697cbfd8ab6f94c5316651a33dc772195db28..ce62965a2c92432ffbf739e933279f91b69e355c 100644 --- a/video_prediction_tools/model_modules/video_prediction/datasets/era5_dataset.py +++ b/video_prediction_tools/model_modules/video_prediction/datasets/era5_dataset.py @@ -181,7 +181,9 @@ class ERA5Dataset(object): dataset = dataset.apply(tf.contrib.data.shuffle_and_repeat(buffer_size =1024, count = self.num_epochs)) else: dataset = dataset.repeat(self.num_epochs) - if self.mode == "val": dataset = dataset.repeat(20) + + if self.mode == "val": dataset = dataset.repeat(20) + num_parallel_calls = None if shuffle else 1 dataset = dataset.apply(tf.contrib.data.map_and_batch( parser, batch_size, drop_remainder=True, num_parallel_calls=num_parallel_calls)) diff --git a/video_prediction_tools/model_modules/video_prediction/models/__init__.py b/video_prediction_tools/model_modules/video_prediction/models/__init__.py index 960f608deed07e715190cdecb38efeb2eb4c5ace..2053aeed83a3606804af959e1c422d5cb39723a7 100644 --- a/video_prediction_tools/model_modules/video_prediction/models/__init__.py +++ b/video_prediction_tools/model_modules/video_prediction/models/__init__.py @@ -12,6 +12,10 @@ from .vanilla_convLSTM_model import VanillaConvLstmVideoPredictionModel from .mcnet_model import McNetVideoPredictionModel from .test_model import TestModelVideoPredictionModel from model_modules.model_architectures import known_models +from .vanilla_GAN_model import VanillaGANVideoPredictionModel +from .convLSTM_GAN_model import ConvLstmGANVideoPredictionModel + + def get_model_class(model): model_mappings = known_models() diff --git a/video_prediction_tools/model_modules/video_prediction/models/convLSTM_GAN_model.py b/video_prediction_tools/model_modules/video_prediction/models/convLSTM_GAN_model.py new file mode 100644 index 0000000000000000000000000000000000000000..3ab3a423a001e903ec4ca9fe1bd7ec78e18dc731 --- /dev/null +++ b/video_prediction_tools/model_modules/video_prediction/models/convLSTM_GAN_model.py @@ -0,0 +1,354 @@ +__email__ = "b.gong@fz-juelich.de" +__author__ = "Bing Gong,Yanji" +__date__ = "2021-04-13" + +import collections +import functools +import itertools +from collections import OrderedDict +import numpy as np +import tensorflow as tf +from tensorflow.python.util import nest +from model_modules.video_prediction import ops, flow_ops +from model_modules.video_prediction.models import BaseVideoPredictionModel +from model_modules.video_prediction.models import networks +from model_modules.video_prediction.ops import dense, 
pad2d, conv2d, flatten, tile_concat +from model_modules.video_prediction.rnn_ops import BasicConv2DLSTMCell, Conv2DGRUCell +from model_modules.video_prediction.utils import tf_utils +from datetime import datetime +from pathlib import Path +from model_modules.video_prediction.layers import layer_def as ld +from model_modules.video_prediction.layers.BasicConvLSTMCell import BasicConvLSTMCell +from tensorflow.contrib.training import HParams +from .vanilla_convLSTM_model import VanillaConvLstmVideoPredictionModel + +class batch_norm(object): + def __init__(self, epsilon=1e-5, momentum = 0.9, name="batch_norm"): + with tf.variable_scope(name): + self.epsilon = epsilon + self.momentum = momentum + self.name = name + + def __call__(self, x, train=True): + return tf.contrib.layers.batch_norm(x, + decay=self.momentum, + updates_collections=None, + epsilon=self.epsilon, + scale=True, + is_training=train, + scope=self.name) +
+class ConvLstmGANVideoPredictionModel(object): + def __init__(self, mode='train', hparams_dict=None): + """ + This class builds the convLSTM_GAN architecture from the given hyperparameters + args: + mode: str, "train" or "val"; note: mode may not be used in the convLSTM, but it is a useful argument for the GAN-based model + hparams_dict: dict, dictionary containing the hyperparameter names and values + """ + self.mode = mode + self.hparams_dict = hparams_dict + self.hparams = self.parse_hparams() + self.learning_rate = self.hparams.lr + self.total_loss = None + self.context_frames = self.hparams.context_frames + self.sequence_length = self.hparams.sequence_length + self.predict_frames = self.sequence_length - self.context_frames + self.max_epochs = self.hparams.max_epochs + self.loss_fun = self.hparams.loss_fun + self.batch_size = self.hparams.batch_size + self.recon_weight = self.hparams.recon_weight + self.bd1 = batch_norm(name = "dis1") + self.bd2 = batch_norm(name = "dis2") + self.bd3 = batch_norm(name = "dis3") + + def get_default_hparams(self): + return HParams(**self.get_default_hparams_dict()) + + def parse_hparams(self): + """ + Parse the hparams settings to override the default ones + """ + + parsed_hparams = self.get_default_hparams().override_from_dict(self.hparams_dict or {}) + return parsed_hparams + + + def get_default_hparams_dict(self): + """ + The function that contains the default hparams + Returns: + A dict with the following hyperparameters. + context_frames : the number of ground-truth frames to pass in at start. + sequence_length : the number of frames in the video sequence + max_epochs : the number of epochs to train the model + lr : learning rate + loss_fun : the loss function + recon_weight : the weight for the reconstruction loss + """ + hparams = dict( + context_frames=12, + sequence_length=24, + max_epochs = 20, + batch_size = 40, + lr = 0.001, + loss_fun = "cross_entropy", + shuffle_on_val= True, + recon_weight=0.99, + + ) + return hparams + +
+ def build_graph(self, x): + self.is_build_graph = False + self.inputs = x + self.x = x["images"] + self.width = self.x.shape.as_list()[3] + self.height = self.x.shape.as_list()[2] + self.channels = self.x.shape.as_list()[4] + self.global_step = tf.train.get_or_create_global_step() + original_global_variables = tf.global_variables() + # Architecture + self.define_gan() + #This is the loss function (RMSE): + #This is loss function only for 1 channel (temperature RMSE) + #generator loss + self.total_loss = (1-self.recon_weight) * self.G_loss + self.recon_weight*self.recon_loss + self.D_loss = (1-self.recon_weight) * self.D_loss + if self.mode == "train": + if self.recon_weight == 1: + print("Only train generator - convLSTM") + self.train_op = tf.train.AdamOptimizer(learning_rate = self.learning_rate).minimize(self.total_loss, var_list=self.gen_vars) + else: + print("Training discriminator") + self.D_solver = tf.train.AdamOptimizer(learning_rate = self.learning_rate).minimize(self.D_loss, var_list=self.disc_vars) + with tf.control_dependencies([self.D_solver]): + print("Training generator....") + self.G_solver = tf.train.AdamOptimizer(learning_rate = self.learning_rate).minimize(self.total_loss, var_list=self.gen_vars) + with tf.control_dependencies([self.G_solver]): + self.train_op = tf.assign_add(self.global_step,1) + else: + self.train_op = None + + self.outputs = {} + self.outputs["gen_images"] = self.gen_images + self.outputs["total_loss"] = self.total_loss + # Summary op + tf.summary.scalar("total_loss", self.total_loss) + tf.summary.scalar("D_loss", self.D_loss) + tf.summary.scalar("G_loss", self.G_loss) + tf.summary.scalar("D_loss_fake", self.D_loss_fake) + tf.summary.scalar("D_loss_real", self.D_loss_real) + tf.summary.scalar("recon_loss",self.recon_loss) + self.summary_op = tf.summary.merge_all() + global_variables = [var for var in tf.global_variables() if var not in original_global_variables] + self.saveable_variables = [self.global_step] + global_variables + self.is_build_graph = True + return self.is_build_graph +
+ def get_noise(self): + """ + Function for creating noise: Given the dimensions (n_batch,n_seq, n_height, n_width, channel) + """ + self.noise = tf.random.uniform(minval=-1., maxval=1., shape=[self.batch_size, self.sequence_length, self.height, self.width, self.channels]) + return self.noise + + @staticmethod + def lrelu(x, leak=0.2, name="lrelu"): + return tf.maximum(x, leak*x) + + @staticmethod + def linear(input_, output_size, scope=None, stddev=0.02, bias_start=0.0, with_w=False): + shape = input_.get_shape().as_list() + + with tf.variable_scope(scope or "Linear"): + matrix = tf.get_variable("Matrix", [shape[1], output_size], tf.float32, + tf.random_normal_initializer(stddev=stddev)) + bias = tf.get_variable("bias", [output_size], + initializer=tf.constant_initializer(bias_start)) + if with_w: + return tf.matmul(input_, matrix) + bias, matrix, bias + else: + return tf.matmul(input_, matrix) + bias + + @staticmethod + def conv2d(input_, output_dim, k_h=5, k_w=5, d_h=2, d_w=2, stddev=0.02, name="conv2d"): + with tf.variable_scope(name): + w = tf.get_variable('w', [k_h, k_w, input_.get_shape()[-1], output_dim], + initializer=tf.truncated_normal_initializer(stddev=stddev)) + conv = tf.nn.conv2d(input_, w, strides=[1, d_h, d_w, 1], padding='SAME') + + biases = tf.get_variable('biases', [output_dim], initializer=tf.constant_initializer(0.0)) + conv = tf.reshape(tf.nn.bias_add(conv, biases), conv.get_shape()) + + return conv + + @staticmethod + def bn(x, scope): + return tf.contrib.layers.batch_norm(x, + decay=0.9, + updates_collections=None, + epsilon=1e-5, + scale=True, + scope=scope) +
+ def generator(self): + """ + Function to build up the generator architecture + args: + input images: an input tensor with dimension (n_batch,sequence_length,height,width,channel) + """ + with tf.variable_scope("generator",reuse=tf.AUTO_REUSE): + layer_gen = self.convLSTM_network(self.x) + layer_gen_pred = layer_gen[:,self.context_frames-1:,:,:,:] + return layer_gen + + + def discriminator(self,vid): + """ + Function that builds the discriminator architecture + """ + with tf.variable_scope("discriminator",reuse=tf.AUTO_REUSE): + conv1 = tf.layers.conv3d(vid,64,kernel_size=[4,4,4],strides=[2,2,2],padding="SAME",name="dis1") + conv1 = ConvLstmGANVideoPredictionModel.lrelu(conv1) + conv2 = tf.layers.conv3d(conv1,128,kernel_size=[4,4,4],strides=[2,2,2],padding="SAME",name="dis2") + conv2 = ConvLstmGANVideoPredictionModel.lrelu(self.bd1(conv2)) + conv3 = tf.layers.conv3d(conv2,256,kernel_size=[4,4,4],strides=[2,2,2],padding="SAME",name="dis3") + conv3 = ConvLstmGANVideoPredictionModel.lrelu(self.bd2(conv3)) + conv4 = tf.layers.conv3d(conv3,512,kernel_size=[4,4,4],strides=[2,2,2],padding="SAME",name="dis4") + conv4 = ConvLstmGANVideoPredictionModel.lrelu(self.bd3(conv4)) + conv5 = tf.layers.conv3d(conv4,1,kernel_size=[2,4,4],strides=[1,1,1],padding="SAME",name="dis5") + conv5 = tf.reshape(conv5, [-1,1]) + conv5sigmoid = tf.nn.sigmoid(conv5) + return conv5sigmoid,conv5 + + def discriminator0(self,image): + """ + Function that builds the discriminator architecture + """ + with tf.variable_scope("discriminator",reuse=tf.AUTO_REUSE): + layer_disc = self.convLSTM_network(image) + layer_disc = layer_disc[:,self.context_frames-1:self.context_frames,:,:, 0:1] + return layer_disc +
+ def discriminator1(self,sequence): + """ + https://github.com/hwalsuklee/tensorflow-generative-model-collections/blob/master/GAN.py + Function that gives the probability that a sequence of frames is true or false + the input sequence shape is like [batch_size,time_seq_length,height,width,channel] (e.g., self.x[:,:self.context_frames,:,:,:]) + """ + with tf.variable_scope("discriminator",reuse=tf.AUTO_REUSE): + print(sequence.shape) + x = sequence[:,:,:,:,0:1] # extract targeted variable + x = tf.transpose(x, [0,2,3,1,4]) # sequence shape is like: [batch_size,height,width,time_seq_length] + x = tf.reshape(x,[x.shape[0],x.shape[1],x.shape[2],x.shape[3]]) + print(x.shape) + net = ConvLstmGANVideoPredictionModel.lrelu(ConvLstmGANVideoPredictionModel.conv2d(x, 64, 4, 4, 2, 2, name='d_conv1')) + net = ConvLstmGANVideoPredictionModel.lrelu(ConvLstmGANVideoPredictionModel.bn(ConvLstmGANVideoPredictionModel.conv2d(net, 128, 4, 4, 2, 2, name='d_conv2'),scope='d_bn2')) + net = tf.reshape(net, [self.batch_size, -1]) + net = ConvLstmGANVideoPredictionModel.lrelu(ConvLstmGANVideoPredictionModel.bn(ConvLstmGANVideoPredictionModel.linear(net, 1024, scope='d_fc3'),scope='d_bn3')) + out_logit = ConvLstmGANVideoPredictionModel.linear(net, 1, scope='d_fc4') + out = tf.nn.sigmoid(out_logit) + print(out.shape) + return out, out_logit +
+ def get_disc_loss(self): + """ + Return the loss of the discriminator given inputs + """ + + real_labels = tf.ones_like(self.D_real) + gen_labels = tf.zeros_like(self.D_fake) + self.D_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.D_real_logits, labels=real_labels)) + self.D_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.D_fake_logits, labels=gen_labels)) + self.D_loss = self.D_loss_real + self.D_loss_fake + return self.D_loss + + + def get_gen_loss(self): + """ + Param: + num_images: the number of images the generator should produce, which is also the length of the real image + z_dim : the dimension of the noise vector, a scalar + Return the loss of the generator given inputs + """ + real_labels = tf.ones_like(self.D_fake) + self.G_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.D_fake_logits, labels=real_labels)) + return self.G_loss + + def get_vars(self): + """ + Get trainable variables from discriminator and generator + """ + print("trainable_variables", len(tf.trainable_variables())) + self.disc_vars = [var for var in tf.trainable_variables() if var.name.startswith("discriminator")] + self.gen_vars = [var for var in tf.trainable_variables() if var.name.startswith("generator")] + print("self.disc_vars",self.disc_vars) + print("self.gen_vars",self.gen_vars) + +
+ def define_gan(self): + """ + Define gan architectures + """ + self.noise = self.get_noise() + self.gen_images = self.generator() + #!!!! the input of the discriminator should be changed when using different discriminators + self.D_real, self.D_real_logits = self.discriminator(self.x[:,self.context_frames:,:,:,:]) + self.D_fake, self.D_fake_logits = self.discriminator(self.gen_images[:,self.context_frames-1:,:,:,:]) + self.get_gen_loss() + self.get_disc_loss() + self.get_vars() + if self.loss_fun == "rmse": + self.recon_loss = tf.reduce_mean(tf.square(self.x[:, self.context_frames:,:,:,0] - self.gen_images[:,self.context_frames-1:,:,:,0])) + elif self.loss_fun == "cross_entropy": + x_flatten = tf.reshape(self.x[:, self.context_frames:,:,:,0],[-1]) + x_hat_predict_frames_flatten = tf.reshape(self.gen_images[:,self.context_frames-1:,:,:,0],[-1]) + bce = tf.keras.losses.BinaryCrossentropy() + self.recon_loss = bce(x_flatten,x_hat_predict_frames_flatten) + else: + raise ValueError("Loss function is not selected properly, you should choose either 'rmse' or 'cross_entropy'") + +
+ @staticmethod + def convLSTM_cell(inputs, hidden): + y_0 = inputs #we only used patch 1, but the original paper uses patch 4 for the moving MNIST case and 2 for the Radar Echo dataset + channels = inputs.get_shape()[-1] + # conv lstm cell + cell_shape = y_0.get_shape().as_list() + channels = cell_shape[-1] + with tf.variable_scope('conv_lstm', initializer = tf.random_uniform_initializer(-.01, 0.1)): + cell = BasicConvLSTMCell(shape = [cell_shape[1], cell_shape[2]], filter_size=5, num_features=64) + if hidden is None: + hidden = cell.zero_state(y_0, tf.float32) + output, hidden = cell(y_0, hidden) + output_shape = output.get_shape().as_list() + z3 = tf.reshape(output, [-1, output_shape[1], output_shape[2], output_shape[3]]) + #we feed the learned representation into a 1 × 1 convolutional layer to generate the final prediction + x_hat = ld.conv_layer(z3, 1, 1, channels, "decode_1", activate="sigmoid") + print('x_hat shape is: ',x_hat.shape) + return x_hat, hidden + + def convLSTM_network(self,x): + network_template =
tf.make_template('network',VanillaConvLstmVideoPredictionModel.convLSTM_cell) # make the template to share the variables + # create network + x_hat = [] + + #This is for training (optimization of convLSTM layer) + hidden_g = None + for i in range(self.sequence_length-1): + if i < self.context_frames: + x_1_g, hidden_g = network_template(x[:, i, :, :, :], hidden_g) + else: + x_1_g, hidden_g = network_template(x_1_g, hidden_g) + x_hat.append(x_1_g) + + # pack them all together + x_hat = tf.stack(x_hat) + self.x_hat= tf.transpose(x_hat, [1, 0, 2, 3, 4]) # swap the first (time) and second (batch) dimensions + print('self.x_hat shape is: ',self.x_hat.shape) + return self.x_hat + + +
diff --git a/video_prediction_tools/model_modules/video_prediction/models/vanilla_GAN_model.py b/video_prediction_tools/model_modules/video_prediction/models/vanilla_GAN_model.py new file mode 100644 index 0000000000000000000000000000000000000000..e0b0d61edcc2464492fbd00e733ff4ce0130c04a --- /dev/null +++ b/video_prediction_tools/model_modules/video_prediction/models/vanilla_GAN_model.py @@ -0,0 +1,242 @@ +__email__ = "b.gong@fz-juelich.de" +__author__ = "Bing Gong" +__date__ = "2021-01-05" + + + +""" +This implementation takes the following as references: +1) https://stackabuse.com/introduction-to-gans-with-python-and-tensorflow/ +2) Coursera GAN courses +3) https://github.com/hwalsuklee/tensorflow-generative-model-collections/blob/master/GAN.py +""" +import collections +import functools +import itertools +from collections import OrderedDict +import numpy as np +import tensorflow as tf +from tensorflow.python.util import nest +from model_modules.video_prediction import ops, flow_ops +from model_modules.video_prediction.models import BaseVideoPredictionModel +from model_modules.video_prediction.models import networks +from model_modules.video_prediction.ops import dense, pad2d, conv2d, flatten, tile_concat +from model_modules.video_prediction.rnn_ops import BasicConv2DLSTMCell, Conv2DGRUCell +from model_modules.video_prediction.utils import tf_utils +from datetime import datetime +from pathlib import Path +from model_modules.video_prediction.layers import layer_def as ld +from model_modules.video_prediction.layers.BasicConvLSTMCell import BasicConvLSTMCell +from tensorflow.contrib.training import HParams +
+class VanillaGANVideoPredictionModel(object): + def __init__(self, mode='train', hparams_dict=None): + """ + This class builds the vanilla GAN architecture from the given hyperparameters + args: + mode: str, "train" or "val"; note: mode may not be used in the convLSTM, but it is a useful argument for the GAN-based model + hparams_dict: dict, dictionary containing the hyperparameter names and values + """ + self.mode = mode + self.hparams_dict = hparams_dict + self.hparams = self.parse_hparams() + self.learning_rate = self.hparams.lr + self.total_loss = None + self.context_frames = self.hparams.context_frames + self.sequence_length = self.hparams.sequence_length + self.predict_frames = self.sequence_length - self.context_frames + self.max_epochs = self.hparams.max_epochs + self.loss_fun = self.hparams.loss_fun + self.batch_size = self.hparams.batch_size + self.z_dim = self.hparams.z_dim #dim of noise-vector + + def get_default_hparams(self): + return HParams(**self.get_default_hparams_dict()) + + def parse_hparams(self): + """ + Parse the hparams settings to override the default ones + """ + + parsed_hparams = self.get_default_hparams().override_from_dict(self.hparams_dict or {}) + return parsed_hparams + + + def get_default_hparams_dict(self): + """ + The function that contains the default hparams + Returns: + A dict with the following hyperparameters. + context_frames : the number of ground-truth frames to pass in at start. + sequence_length : the number of frames in the video sequence + max_epochs : the number of epochs to train the model + lr : learning rate + loss_fun : the loss function + """ + hparams = dict( + context_frames=12, + sequence_length=24, + max_epochs = 20, + batch_size = 40, + lr = 0.001, + loss_fun = "cross_entropy", + shuffle_on_val= True, + z_dim = 32, + ) + return hparams + +
+ def build_graph(self, x): + self.is_build_graph = False + self.x = x["images"] + self.width = self.x.shape.as_list()[3] + self.height = self.x.shape.as_list()[2] + self.channels = self.x.shape.as_list()[4] + self.n_samples = self.x.shape.as_list()[0] * self.x.shape.as_list()[1] + self.x = tf.reshape(self.x, [-1, self.height,self.width,self.channels]) + self.global_step = tf.train.get_or_create_global_step() + original_global_variables = tf.global_variables() + # Architecture + self.define_gan() + #This is the loss function (RMSE): + #This is loss function only for 1 channel (temperature RMSE) + if self.mode == "train": + self.D_solver = tf.train.AdamOptimizer(learning_rate = self.learning_rate).minimize(self.D_loss, var_list=self.disc_vars) + with tf.control_dependencies([self.D_solver]): + self.G_solver = tf.train.AdamOptimizer(learning_rate = self.learning_rate).minimize(self.G_loss, var_list=self.gen_vars) + with tf.control_dependencies([self.G_solver]): + self.train_op = tf.assign_add(self.global_step,1) + else: + self.train_op = None + self.total_loss = self.G_loss + self.D_loss + self.outputs = {} + self.outputs["gen_images"] = self.gen_images + self.outputs["total_loss"] = self.total_loss + # Summary op + self.loss_summary = tf.summary.scalar("total_loss", self.G_loss + self.D_loss) + self.summary_op = tf.summary.merge_all() + global_variables = [var for var in tf.global_variables() if var not in original_global_variables] + self.saveable_variables = [self.global_step] + global_variables + self.is_build_graph = True + return self.is_build_graph +
+ def get_noise(self): + """ + Function for creating noise: Given the dimensions (n_samples, height, width, channel) + """ + self.noise = tf.random.uniform(minval=-1., maxval=1., shape=[self.n_samples, self.height, self.width, self.channels]) + return self.noise + + def get_generator_block(self,inputs,output_dim,idx): + + """ + Generator Block + Function that returns a neural network block of the generator given input and output dimensions + args: + inputs : the input vector + output_dim: the dimension of the output vector + return: + a generator neural network layer, with a convolutional layer followed by batch normalization and a ReLU activation + + """ + output1 = ld.conv_layer(inputs,kernel_size=2,stride=1,num_features=output_dim,idx=idx,activate="linear") + output2 = ld.bn_layers(output1,idx,is_training=False) + output3 = tf.nn.relu(output2) + return output3 + + + def generator(self,hidden_dim): + """ + Function to build up the generator architecture + args: + noise: a noise tensor with dimension (n_samples,height,width,channel) + hidden_dim: the inner dimension + """ + with tf.variable_scope("generator",reuse=tf.AUTO_REUSE): + layer1 = self.get_generator_block(self.noise,hidden_dim,1) + layer2 = self.get_generator_block(layer1,hidden_dim*2,2) + layer3 = self.get_generator_block(layer2,hidden_dim*4,3) + layer4 = self.get_generator_block(layer3,hidden_dim*8,4) + layer5 = ld.conv_layer(layer4,kernel_size=2,stride=1,num_features=self.channels,idx=5,activate="linear") + layer6 = tf.nn.sigmoid(layer5,name="6_conv") + print("layer6",layer6) + return layer6 + + +
+ def get_discriminator_block(self,inputs,output_dim,idx): + + """ + Discriminator block + Function that returns a neural network block of the discriminator given input and output dimensions + + args: + inputs : the input vector + output_dim: the output dimension + idx : the index for the namespace of this block + Return: + a discriminator neural network layer with a convolutional layer followed by a LeakyReLU activation + """ + output1 = ld.conv_layer(inputs,2,stride=1,num_features=output_dim,idx=idx,activate="linear") + output2 = tf.nn.leaky_relu(output1) + return output2 + + + def discriminator(self,image,hidden_dim): + """ + Function that builds the discriminator architecture + """ + with tf.variable_scope("discriminator",reuse=tf.AUTO_REUSE): + layer1 = self.get_discriminator_block(image,hidden_dim,idx=1) + layer2 = self.get_discriminator_block(layer1,hidden_dim*4,idx=2) + layer3 = self.get_discriminator_block(layer2,hidden_dim*2,idx=3) + layer4 = self.get_discriminator_block(layer3, self.channels,idx=4) + layer5 = tf.nn.sigmoid(layer4) + return layer5 + +
+ def get_disc_loss(self): + """ + Return the loss of the discriminator given inputs + """ + + real_labels = tf.ones_like(self.D_real) + gen_labels = tf.zeros_like(self.D_fake) + D_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.D_real, labels=real_labels)) + D_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.D_fake, labels=gen_labels)) + self.D_loss = D_loss_real + D_loss_fake + return self.D_loss + + + def get_gen_loss(self): + """ + Param: + num_images: the number of images the generator should produce, which is also the length of the real image + z_dim : the dimension of the noise vector, a scalar + Return the loss of the generator given inputs + """ + real_labels = tf.ones_like(self.gen_images) + self.G_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.D_fake, labels=real_labels)) + return self.G_loss + + def get_vars(self): + """ + Get trainable variables from discriminator and generator + """ + self.disc_vars = [var for var in tf.trainable_variables() if var.name.startswith("discriminator")] + self.gen_vars = [var for var in tf.trainable_variables() if var.name.startswith("generator")] + + + + def define_gan(self): + """ + Define gan architectures + """ + self.noise = self.get_noise() + self.gen_images = self.generator(hidden_dim=8) + self.D_real = self.discriminator(self.x,hidden_dim=8) + self.D_fake = self.discriminator(self.gen_images,hidden_dim=8) + self.get_gen_loss() + self.get_disc_loss() + self.get_vars() +
diff --git a/video_prediction_tools/model_modules/video_prediction/models/vanilla_convLSTM_model.py b/video_prediction_tools/model_modules/video_prediction/models/vanilla_convLSTM_model.py index 58172bca0401cdc2b2a4353ac2aeee092d59774a..1780b2e8439341320fe5726dab8f7174225a5956 100644 --- a/video_prediction_tools/model_modules/video_prediction/models/vanilla_convLSTM_model.py +++ b/video_prediction_tools/model_modules/video_prediction/models/vanilla_convLSTM_model.py @@ -8,6 +8,8 @@ from model_modules.video_prediction.layers import layer_def as ld from model_modules.video_prediction.layers.BasicConvLSTMCell import BasicConvLSTMCell from tensorflow.contrib.training import HParams + + class VanillaConvLstmVideoPredictionModel(object): def __init__(self,
mode='train', hparams_dict=None): """ @@ -65,6 +67,7 @@ class VanillaConvLstmVideoPredictionModel(object): def build_graph(self, x): self.is_build_graph = False + self.inputs = x self.x = x["images"] self.global_step = tf.train.get_or_create_global_step() original_global_variables = tf.global_variables()
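
A minimal usage sketch of the newly registered 'convLSTM_gan' model, assuming get_model_class() resolves the key via known_models() in the same way as for the existing models; the template path, placeholder input, and the height/width/channel values below are illustrative assumptions only:

# hedged sketch, not part of the patch; shapes are hypothetical
import json
import tensorflow as tf
from model_modules.video_prediction.models import get_model_class

# load the hyperparameters added in model_hparams_template.json (batch_size=4, context_frames=12, sequence_length=24, ...)
with open("video_prediction_tools/hparams/era5/convLSTM_gan/model_hparams_template.json") as f:
    hparams_dict = json.load(f)

VideoPredictionModel = get_model_class("convLSTM_gan")  # -> ConvLstmGANVideoPredictionModel
model = VideoPredictionModel(mode="train", hparams_dict=hparams_dict)

# build_graph expects a dict with an "images" tensor of shape
# [batch_size, sequence_length, height, width, channels]
inputs = {"images": tf.placeholder(tf.float32, [4, 24, 56, 92, 3])}
model.build_graph(inputs)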