diff --git a/video_prediction_tools/model_modules/video_prediction/models/weatherBench3DCNN.py b/video_prediction_tools/model_modules/video_prediction/models/weatherBench3DCNN.py new file mode 100644 index 0000000000000000000000000000000000000000..524d1de86999d99f3d583f5dd33e6daa798a3c96 --- /dev/null +++ b/video_prediction_tools/model_modules/video_prediction/models/weatherBench3DCNN.py @@ -0,0 +1,144 @@ +# SPDX-FileCopyrightText: 2021 Earth System Data Exploration (ESDE), Jülich Supercomputing Center (JSC) +# +# SPDX-License-Identifier: MIT +# Weather Bench models +__email__ = "b.gong@fz-juelich.de" +__author__ = "Bing Gong" +__date__ = "2021-04-13" + +import tensorflow as tf +from tensorflow.contrib.training import HParams +from model_modules.video_prediction.layers import layer_def as ld +from model_modules.video_prediction.losses import * + +class WeatherBenchModel(object): + + def __init__(self, hparams_dict=None): + """ + This is class for building weahterBench architecture by using updated hparameters + args: + mode :str, "train" or "val", side note: mode may not be used in the convLSTM, but this will be a useful argument for the GAN-based model + hparams_dict: dict, the dictionary contains the hparaemters names and values + """ + self.hparams_dict = hparams_dict + self.hparams = self.parse_hparams() + self.learning_rate = self.hparams.lr + self.filters = self.hparams.filters + self.kernels = self.hparams.kernes + self.context_frames = self.hparams.context_frames + self.sequence_length = self.hparams.sequence_length + self.predict_frames = self.sequence_length- self.context_frames + self.max_epochs = self.hparams.max_epochs + self.loss_fun = self.hparams.loss_fun + self.batch_size = self.hparams.batch_size + self.recon_weight = self.hparams.recon_weight + self.outputs = {} + self.total_loss = None + + def get_default_hparams(self): + return HParams(**self.get_default_hparams_dict()) + + def parse_hparams(self): + """ + Parse the hparams setting to ovoerride the default ones + """ + + parsed_hparams = self.get_default_hparams().override_from_dict(self.hparams_dict or {}) + return parsed_hparams + + + def get_default_hparams_dict(self): + """ + The function that contains default hparams + Returns: + A dict with the following hyperparameters. + context_frames : the number of ground-truth frames to pass in at start. + sequence_length : the number of frames in the video sequence + max_epochs : the number of epochs to train model + lr : learning rate + loss_fun : the loss function + """ + hparams = dict( + context_frames =12, + sequence_length =24, + max_epochs = 20, + batch_size = 40, + lr = 0.001, + loss_fun = "mse", + shuffle_on_val= True, + filter = 4, + kernels = 4 + ) + return hparams + + + def build_graph(self, x): + self.is_build_graph = False + self.inputs = x + self.x = x["images"] + + self.global_step = tf.train.get_or_create_global_step() + original_global_variables = tf.global_variables() + + # Architecture + x_hat = self.build_model(x, self.filters, self.kernels, dr=0) + # Loss + self.total_loss = l1_loss(x[...,0], x_hat[...,0]) + # Optimizer + self.train_op = tf.train.AdamOptimizer(learning_rate = self.learning_rate).minimize(self.total_loss, var_list=self.gen_vars) + + # outputs + self.outputs["total_loss"] = self.total_loss + + # Summary op + tf.summary.scalar("total_loss", self.total_loss) + self.summary_op = tf.summary.merge_all() + global_variables = [var for var in tf.global_variables() if var not in original_global_variables] + self.saveable_variables = [self.global_step] + global_variables + self.is_build_graph = True + return self.is_build_graph + + + def build_model(self, x,filters, kernels, dr=0): + """Fully convolutional network""" + for f, k in zip(filters[:-1], kernels[:-1]): + x = PeriodicConv2D(x, f, k) + x = tf.nn.elu(x) + if dr > 0: x = tf.nn.dropout(x, dr) + output = PeriodicConv2D(x, filters[-1], kernels[-1]) + return output + + +class PeriodicPadding2D(object): + def __init__(self, x, pad_width): + + self.pad_width = pad_width + + def call(self, inputs, **kwargs): + if self.pad_width == 0: + return inputs + inputs_padded = tf.concat( + [inputs[:, :, -self.pad_width:, :], inputs, inputs[:, :, :self.pad_width, :]], axis=2) + + # Zero padding in the lat direction + inputs_padded = tf.pad(inputs_padded, [[0, 0], [self.pad_width, self.pad_width], [0, 0], [0, 0]]) + return inputs_padded + + +class PeriodicConv2D(object): + + def __init__(self, filters, kernel_size, conv_kwargs={}): + self.filters = filters + self.kernel_size = kernel_size + self.conv_kwargs = conv_kwargs + if type(kernel_size) is not int: + assert kernel_size[0] == kernel_size[1], 'PeriodicConv2D only works for square kernels' + kernel_size = kernel_size[0] + self.pad_width = (kernel_size - 1) // 2 + + def call(self,inputs): + self.padding = PeriodicPadding2D(inputs, self.pad_width) + self.conv = ld.conv2D(self.padding, self.filters, self.kernel_size, padding='valid') + + +