Skip to content
Snippets Groups Projects
Commit 4f5441d2 authored by BING GONG's avatar BING GONG
Browse files

Add weatherBench models

parent a503159a
No related branches found
No related tags found
No related merge requests found
Pipeline #101962 passed
# SPDX-FileCopyrightText: 2021 Earth System Data Exploration (ESDE), Jülich Supercomputing Center (JSC)
#
# SPDX-License-Identifier: MIT
# Weather Bench models
__email__ = "b.gong@fz-juelich.de"
__author__ = "Bing Gong"
__date__ = "2021-04-13"
import tensorflow as tf
from tensorflow.contrib.training import HParams
from model_modules.video_prediction.layers import layer_def as ld
from model_modules.video_prediction.losses import *
class WeatherBenchModel(object):
def __init__(self, hparams_dict=None):
"""
This is class for building weahterBench architecture by using updated hparameters
args:
mode :str, "train" or "val", side note: mode may not be used in the convLSTM, but this will be a useful argument for the GAN-based model
hparams_dict: dict, the dictionary contains the hparaemters names and values
"""
self.hparams_dict = hparams_dict
self.hparams = self.parse_hparams()
self.learning_rate = self.hparams.lr
self.filters = self.hparams.filters
self.kernels = self.hparams.kernes
self.context_frames = self.hparams.context_frames
self.sequence_length = self.hparams.sequence_length
self.predict_frames = self.sequence_length- self.context_frames
self.max_epochs = self.hparams.max_epochs
self.loss_fun = self.hparams.loss_fun
self.batch_size = self.hparams.batch_size
self.recon_weight = self.hparams.recon_weight
self.outputs = {}
self.total_loss = None
def get_default_hparams(self):
return HParams(**self.get_default_hparams_dict())
def parse_hparams(self):
"""
Parse the hparams setting to ovoerride the default ones
"""
parsed_hparams = self.get_default_hparams().override_from_dict(self.hparams_dict or {})
return parsed_hparams
def get_default_hparams_dict(self):
"""
The function that contains default hparams
Returns:
A dict with the following hyperparameters.
context_frames : the number of ground-truth frames to pass in at start.
sequence_length : the number of frames in the video sequence
max_epochs : the number of epochs to train model
lr : learning rate
loss_fun : the loss function
"""
hparams = dict(
context_frames =12,
sequence_length =24,
max_epochs = 20,
batch_size = 40,
lr = 0.001,
loss_fun = "mse",
shuffle_on_val= True,
filter = 4,
kernels = 4
)
return hparams
def build_graph(self, x):
self.is_build_graph = False
self.inputs = x
self.x = x["images"]
self.global_step = tf.train.get_or_create_global_step()
original_global_variables = tf.global_variables()
# Architecture
x_hat = self.build_model(x, self.filters, self.kernels, dr=0)
# Loss
self.total_loss = l1_loss(x[...,0], x_hat[...,0])
# Optimizer
self.train_op = tf.train.AdamOptimizer(learning_rate = self.learning_rate).minimize(self.total_loss, var_list=self.gen_vars)
# outputs
self.outputs["total_loss"] = self.total_loss
# Summary op
tf.summary.scalar("total_loss", self.total_loss)
self.summary_op = tf.summary.merge_all()
global_variables = [var for var in tf.global_variables() if var not in original_global_variables]
self.saveable_variables = [self.global_step] + global_variables
self.is_build_graph = True
return self.is_build_graph
def build_model(self, x,filters, kernels, dr=0):
"""Fully convolutional network"""
for f, k in zip(filters[:-1], kernels[:-1]):
x = PeriodicConv2D(x, f, k)
x = tf.nn.elu(x)
if dr > 0: x = tf.nn.dropout(x, dr)
output = PeriodicConv2D(x, filters[-1], kernels[-1])
return output
class PeriodicPadding2D(object):
def __init__(self, x, pad_width):
self.pad_width = pad_width
def call(self, inputs, **kwargs):
if self.pad_width == 0:
return inputs
inputs_padded = tf.concat(
[inputs[:, :, -self.pad_width:, :], inputs, inputs[:, :, :self.pad_width, :]], axis=2)
# Zero padding in the lat direction
inputs_padded = tf.pad(inputs_padded, [[0, 0], [self.pad_width, self.pad_width], [0, 0], [0, 0]])
return inputs_padded
class PeriodicConv2D(object):
def __init__(self, filters, kernel_size, conv_kwargs={}):
self.filters = filters
self.kernel_size = kernel_size
self.conv_kwargs = conv_kwargs
if type(kernel_size) is not int:
assert kernel_size[0] == kernel_size[1], 'PeriodicConv2D only works for square kernels'
kernel_size = kernel_size[0]
self.pad_width = (kernel_size - 1) // 2
def call(self,inputs):
self.padding = PeriodicPadding2D(inputs, self.pad_width)
self.conv = ld.conv2D(self.padding, self.filters, self.kernel_size, padding='valid')
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment