From 9707d42d13f166939586afb96add8fd22a97aec4 Mon Sep 17 00:00:00 2001 From: gong1 <b.gong@fz-juelich.de> Date: Tue, 17 Jan 2023 17:55:22 +0100 Subject: [PATCH] update the modular of models --- .../model_modules/video_prediction/models/__init__.py | 2 +- .../video_prediction/models/weatherBench3Dcnn.py | 7 +++---- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/video_prediction_tools/model_modules/video_prediction/models/__init__.py b/video_prediction_tools/model_modules/video_prediction/models/__init__.py index b2814ecd..c2470aea 100644 --- a/video_prediction_tools/model_modules/video_prediction/models/__init__.py +++ b/video_prediction_tools/model_modules/video_prediction/models/__init__.py @@ -7,7 +7,7 @@ from .vanilla_convLSTM_model import VanillaConvLstmVideoPredictionModel from .test_model import TestModelVideoPredictionModel from model_modules.model_architectures import known_models from .convLSTM_GAN_model import ConvLstmGANVideoPredictionModel - +from .weatherBench3Dcnn import WeatherBenchModel def get_model_class(model): model_mappings = known_models() diff --git a/video_prediction_tools/model_modules/video_prediction/models/weatherBench3Dcnn.py b/video_prediction_tools/model_modules/video_prediction/models/weatherBench3Dcnn.py index ceba5355..033bd499 100644 --- a/video_prediction_tools/model_modules/video_prediction/models/weatherBench3Dcnn.py +++ b/video_prediction_tools/model_modules/video_prediction/models/weatherBench3Dcnn.py @@ -13,9 +13,6 @@ from .our_base_model import BaseModels class WeatherBenchModel(BaseModels): - filters = [64, 64, 64, 64, 2] - kernels = [5, 5, 5, 5, 5] - def __init__(self, hparams_dict_config: dict=None, mode:str="train", **kwargs): """ This is class for building weatherBench architecture by using updated hparameters @@ -58,8 +55,10 @@ class WeatherBenchModel(BaseModels): def build_model(self, x): """Fully convolutional network""" x = x[:, 0, :, :, :] - _idx = 0 + filters = [64, 64, 64, 64, 2] + kernels = [5, 5, 5, 5, 5] + for f, k in zip(filters[:-1], kernels[:-1]): with tf.variable_scope("conv_layer_"+str(_idx), reuse=tf.AUTO_REUSE): x = ld.conv_layer(x, kernel_size=k, stride=1, -- GitLab