diff --git a/video_prediction_tools/model_modules/model_architectures.py b/video_prediction_tools/model_modules/model_architectures.py index b33ed8c570adc622315965c10fd98b141147dad7..79c4b5c67e8e5bd01fba57ec6a43cc70a06f107c 100644 --- a/video_prediction_tools/model_modules/model_architectures.py +++ b/video_prediction_tools/model_modules/model_architectures.py @@ -13,7 +13,7 @@ def known_models(): 'convLSTM_gan': "ConvLstmGANVideoPredictionModel", 'ours_vae_l1': 'SAVPVideoPredictionModel', 'ours_gan': 'SAVPVideoPredictionModel', - "weatherBench": "WeatherBenchModel " + "weatherBench": "WeatherBenchModel" } return model_mappings diff --git a/video_prediction_tools/model_modules/video_prediction/models/weatherBench3DCNN.py b/video_prediction_tools/model_modules/video_prediction/models/weatherBench3DCNN.py index 725e41388681093385bca7ea53905ad2e308e7b1..ca7b6f0b2ad86e2c62510fc21d9d0d66f8babf06 100644 --- a/video_prediction_tools/model_modules/video_prediction/models/weatherBench3DCNN.py +++ b/video_prediction_tools/model_modules/video_prediction/models/weatherBench3DCNN.py @@ -13,7 +13,7 @@ from model_modules.video_prediction.losses import * class WeatherBenchModel(object): - def __init__(self, hparams_dict=None): + def __init__(self, hparams_dict=None,**kwargs): """ This is class for building weahterBench architecture by using updated hparameters args: @@ -67,7 +67,8 @@ class WeatherBenchModel(object): loss_fun = "mse", shuffle_on_val= True, filter = 4, - kernels = 4 + kernels = 4, + ) return hparams