From 9707d42d13f166939586afb96add8fd22a97aec4 Mon Sep 17 00:00:00 2001
From: gong1 <b.gong@fz-juelich.de>
Date: Tue, 17 Jan 2023 17:55:22 +0100
Subject: [PATCH] update the modular of models

---
 .../model_modules/video_prediction/models/__init__.py      | 2 +-
 .../video_prediction/models/weatherBench3Dcnn.py           | 7 +++----
 2 files changed, 4 insertions(+), 5 deletions(-)

diff --git a/video_prediction_tools/model_modules/video_prediction/models/__init__.py b/video_prediction_tools/model_modules/video_prediction/models/__init__.py
index b2814ecd..c2470aea 100644
--- a/video_prediction_tools/model_modules/video_prediction/models/__init__.py
+++ b/video_prediction_tools/model_modules/video_prediction/models/__init__.py
@@ -7,7 +7,7 @@ from .vanilla_convLSTM_model import VanillaConvLstmVideoPredictionModel
 from .test_model import TestModelVideoPredictionModel
 from model_modules.model_architectures import known_models
 from .convLSTM_GAN_model import ConvLstmGANVideoPredictionModel
-
+from .weatherBench3Dcnn import  WeatherBenchModel
 
 def get_model_class(model):
     model_mappings = known_models()
diff --git a/video_prediction_tools/model_modules/video_prediction/models/weatherBench3Dcnn.py b/video_prediction_tools/model_modules/video_prediction/models/weatherBench3Dcnn.py
index ceba5355..033bd499 100644
--- a/video_prediction_tools/model_modules/video_prediction/models/weatherBench3Dcnn.py
+++ b/video_prediction_tools/model_modules/video_prediction/models/weatherBench3Dcnn.py
@@ -13,9 +13,6 @@ from .our_base_model import BaseModels
 
 class WeatherBenchModel(BaseModels):
 
-    filters = [64, 64, 64, 64, 2]
-    kernels = [5, 5, 5, 5, 5]
-
     def __init__(self, hparams_dict_config: dict=None, mode:str="train", **kwargs):
         """
         This is class for building weatherBench architecture by using updated hparameters
@@ -58,8 +55,10 @@ class WeatherBenchModel(BaseModels):
     def build_model(self, x):
         """Fully convolutional network"""
         x = x[:, 0, :, :, :]
-
         _idx = 0
+        filters = [64, 64, 64, 64, 2]
+        kernels = [5, 5, 5, 5, 5]
+
         for f, k in zip(filters[:-1], kernels[:-1]):
             with tf.variable_scope("conv_layer_"+str(_idx), reuse=tf.AUTO_REUSE):
                 x = ld.conv_layer(x, kernel_size=k, stride=1,
-- 
GitLab