From add452bfa32d1a9a138f14cf9b678063c4ced361 Mon Sep 17 00:00:00 2001 From: BING GONG <b.gong@fz-juelich.de> Date: Wed, 8 Jun 2022 18:09:12 +0200 Subject: [PATCH] minor bugs for weatherBench3DCNN model --- .../era5/weatherBench/model_hparams_template.json | 13 +++++++++++++ .../video_prediction/models/weatherBench3DCNN.py | 5 +++-- 2 files changed, 16 insertions(+), 2 deletions(-) create mode 100644 video_prediction_tools/hparams/era5/weatherBench/model_hparams_template.json diff --git a/video_prediction_tools/hparams/era5/weatherBench/model_hparams_template.json b/video_prediction_tools/hparams/era5/weatherBench/model_hparams_template.json new file mode 100644 index 00000000..4f3a43f1 --- /dev/null +++ b/video_prediction_tools/hparams/era5/weatherBench/model_hparams_template.json @@ -0,0 +1,13 @@ + +{ + "batch_size": 4, + "lr": 0.001, + "max_epochs":20, + "context_frames":12, + "loss_fun":"mse", + "opt_var": "0", + "shuffle_on_val":true +} + + + diff --git a/video_prediction_tools/model_modules/video_prediction/models/weatherBench3DCNN.py b/video_prediction_tools/model_modules/video_prediction/models/weatherBench3DCNN.py index 09107827..725e4138 100644 --- a/video_prediction_tools/model_modules/video_prediction/models/weatherBench3DCNN.py +++ b/video_prediction_tools/model_modules/video_prediction/models/weatherBench3DCNN.py @@ -99,7 +99,7 @@ class WeatherBenchModel(object): return self.is_build_graph - def build_model(self, x,filters, kernels, dr=0): + def build_model(self, x, filters, kernels, dr=0): """Fully convolutional network""" for f, k in zip(filters[:-1], kernels[:-1]): x = PeriodicConv2D(x, f, k) @@ -109,7 +109,6 @@ class WeatherBenchModel(object): return output - class PeriodicPadding2D(object): def __init__(self, x, pad_width): @@ -118,6 +117,7 @@ class PeriodicPadding2D(object): def call(self, inputs, **kwargs): if self.pad_width == 0: return inputs + inputs_padded = tf.concat( [inputs[:, :, -self.pad_width:, :], inputs, inputs[:, :, :self.pad_width, :]], axis=2) @@ -126,6 +126,7 @@ class PeriodicPadding2D(object): return inputs_padded + class PeriodicConv2D(object): def __init__(self, filters, kernel_size, conv_kwargs={}): -- GitLab