diff --git a/video_prediction_tools/model_modules/model_architectures.py b/video_prediction_tools/model_modules/model_architectures.py index b8793ad84ee1053e2d98c43941fcb6f6a7ee8088..b33ed8c570adc622315965c10fd98b141147dad7 100644 --- a/video_prediction_tools/model_modules/model_architectures.py +++ b/video_prediction_tools/model_modules/model_architectures.py @@ -13,6 +13,7 @@ def known_models(): 'convLSTM_gan': "ConvLstmGANVideoPredictionModel", 'ours_vae_l1': 'SAVPVideoPredictionModel', 'ours_gan': 'SAVPVideoPredictionModel', + "weatherBench": "WeatherBenchModel " } return model_mappings diff --git a/video_prediction_tools/model_modules/video_prediction/models/__init__.py b/video_prediction_tools/model_modules/video_prediction/models/__init__.py index a4b60965d6e03a7ccbb4197c3c6c237944a101c5..290def9f7c934f871fedd8c1703bbc114c822dcd 100644 --- a/video_prediction_tools/model_modules/video_prediction/models/__init__.py +++ b/video_prediction_tools/model_modules/video_prediction/models/__init__.py @@ -14,8 +14,7 @@ from .mcnet_model import McNetVideoPredictionModel from .test_model import TestModelVideoPredictionModel from model_modules.model_architectures import known_models from .convLSTM_GAN_model import ConvLstmGANVideoPredictionModel - - +from .weatherBench3DCNN import WeatherBenchModel def get_model_class(model): model_mappings = known_models()