diff --git a/video_prediction_tools/model_modules/video_prediction/models/__init__.py b/video_prediction_tools/model_modules/video_prediction/models/__init__.py index b2814ecdf01af4e373e6bdfb165a6b53f1d03e00..c2470aeadd8c7decaa60b315f5fdf017307ec5a9 100644 --- a/video_prediction_tools/model_modules/video_prediction/models/__init__.py +++ b/video_prediction_tools/model_modules/video_prediction/models/__init__.py @@ -7,7 +7,7 @@ from .vanilla_convLSTM_model import VanillaConvLstmVideoPredictionModel from .test_model import TestModelVideoPredictionModel from model_modules.model_architectures import known_models from .convLSTM_GAN_model import ConvLstmGANVideoPredictionModel - +from .weatherBench3Dcnn import WeatherBenchModel def get_model_class(model): model_mappings = known_models() diff --git a/video_prediction_tools/model_modules/video_prediction/models/weatherBench3Dcnn.py b/video_prediction_tools/model_modules/video_prediction/models/weatherBench3Dcnn.py index ceba535518b3baa1579ba4634d98ba3127753b86..033bd499ef5717391d9fab8cced03b268135b4c9 100644 --- a/video_prediction_tools/model_modules/video_prediction/models/weatherBench3Dcnn.py +++ b/video_prediction_tools/model_modules/video_prediction/models/weatherBench3Dcnn.py @@ -13,9 +13,6 @@ from .our_base_model import BaseModels class WeatherBenchModel(BaseModels): - filters = [64, 64, 64, 64, 2] - kernels = [5, 5, 5, 5, 5] - def __init__(self, hparams_dict_config: dict=None, mode:str="train", **kwargs): """ This is class for building weatherBench architecture by using updated hparameters @@ -58,8 +55,10 @@ class WeatherBenchModel(BaseModels): def build_model(self, x): """Fully convolutional network""" x = x[:, 0, :, :, :] - _idx = 0 + filters = [64, 64, 64, 64, 2] + kernels = [5, 5, 5, 5, 5] + for f, k in zip(filters[:-1], kernels[:-1]): with tf.variable_scope("conv_layer_"+str(_idx), reuse=tf.AUTO_REUSE): x = ld.conv_layer(x, kernel_size=k, stride=1,