From add452bfa32d1a9a138f14cf9b678063c4ced361 Mon Sep 17 00:00:00 2001
From: BING GONG <b.gong@fz-juelich.de>
Date: Wed, 8 Jun 2022 18:09:12 +0200
Subject: [PATCH] minor bugs for weatherBench3DCNN model

---
 .../era5/weatherBench/model_hparams_template.json   | 13 +++++++++++++
 .../video_prediction/models/weatherBench3DCNN.py    |  5 +++--
 2 files changed, 16 insertions(+), 2 deletions(-)
 create mode 100644 video_prediction_tools/hparams/era5/weatherBench/model_hparams_template.json

diff --git a/video_prediction_tools/hparams/era5/weatherBench/model_hparams_template.json b/video_prediction_tools/hparams/era5/weatherBench/model_hparams_template.json
new file mode 100644
index 00000000..4f3a43f1
--- /dev/null
+++ b/video_prediction_tools/hparams/era5/weatherBench/model_hparams_template.json
@@ -0,0 +1,13 @@
+
+{
+    "batch_size": 4,
+    "lr": 0.001,
+    "max_epochs":20,
+    "context_frames":12,
+    "loss_fun":"mse",
+    "opt_var": "0",
+    "shuffle_on_val":true
+}
+
+
+
diff --git a/video_prediction_tools/model_modules/video_prediction/models/weatherBench3DCNN.py b/video_prediction_tools/model_modules/video_prediction/models/weatherBench3DCNN.py
index 09107827..725e4138 100644
--- a/video_prediction_tools/model_modules/video_prediction/models/weatherBench3DCNN.py
+++ b/video_prediction_tools/model_modules/video_prediction/models/weatherBench3DCNN.py
@@ -99,7 +99,7 @@ class WeatherBenchModel(object):
         return self.is_build_graph
 
 
-   def build_model(self, x,filters, kernels, dr=0):
+    def build_model(self, x, filters, kernels, dr=0):
         """Fully convolutional network"""
         for f, k in zip(filters[:-1], kernels[:-1]):
             x = PeriodicConv2D(x, f, k)
@@ -109,7 +109,6 @@ class WeatherBenchModel(object):
         return output
 
 
-
 class PeriodicPadding2D(object):
     def __init__(self, x, pad_width):
 
@@ -118,6 +117,7 @@ class PeriodicPadding2D(object):
     def call(self, inputs, **kwargs):
         if self.pad_width == 0:
             return inputs
+
         inputs_padded = tf.concat(
             [inputs[:, :, -self.pad_width:, :], inputs, inputs[:, :, :self.pad_width, :]], axis=2)
 
@@ -126,6 +126,7 @@ class PeriodicPadding2D(object):
         return inputs_padded
 
 
+
 class PeriodicConv2D(object):
 
     def __init__(self, filters, kernel_size, conv_kwargs={}):
-- 
GitLab