diff --git a/video_prediction_tools/data_preprocess/era5_varmapping_template.json b/video_prediction_tools/data_preprocess/era5_varmapping_template.json
index cb83066c939ef9564e5d11960f3eb81376395b65..d4aa1d16a18c50b37b42a3a8cd89c8da80214155 100644
--- a/video_prediction_tools/data_preprocess/era5_varmapping_template.json
+++ b/video_prediction_tools/data_preprocess/era5_varmapping_template.json
@@ -9,9 +9,7 @@
 #              The value of the 'pl'-key denotes the pressure level (in Pa) onto which the data is interpolated
 #              !!! This file should be only adapted if you are familiar with the ERA5 grib files!!!
 {
-"surface":{
-   ["2t", "tcc","msl","10u","10v"]
-       },
+"surface": ["2t", "tcc","msl","10u","10v"],
 
 "multi":{
     "t" : {
diff --git a/video_prediction_tools/main_scripts/main_train_models.py b/video_prediction_tools/main_scripts/main_train_models.py
index 5242c8848bd5c0fa4ee8927532a2c8e7bf15743d..ddfc6d43cd24591e7a24f16395e9e6be1ee4fecc 100644
--- a/video_prediction_tools/main_scripts/main_train_models.py
+++ b/video_prediction_tools/main_scripts/main_train_models.py
@@ -404,29 +404,29 @@ class TrainModel(object):
             fetch_list = fetch_list + ["L_p", "L_gdl", "L_GAN"]
             self.saver_loss = fetch_list[-3]  # ML: Is this a reasonable choice?
             self.saver_loss_name = "Loss"
-        if self.video_model.__class__.__name__ == "VanillaConvLstmVideoPredictionModel":
+        elif self.video_model.__class__.__name__ == "VanillaConvLstmVideoPredictionModel":
             fetch_list = fetch_list + ["inputs", "total_loss"]
             self.saver_loss = fetch_list[-1]
             self.saver_loss_name = "Total loss"
-        if self.video_model.__class__.__name__ == "SAVPVideoPredictionModel":
+        elif self.video_model.__class__.__name__ == "SAVPVideoPredictionModel":
             fetch_list = fetch_list + ["g_losses", "d_losses", "d_loss", "g_loss", ("g_losses", "gen_l1_loss")]
             # Add loss that is tracked
             self.saver_loss = fetch_list[-1][1]                
             self.saver_loss_dict = fetch_list[-1][0]
             self.saver_loss_name = "Generator L1 loss"
-        if self.video_model.__class__.__name__ == "VanillaVAEVideoPredictionModel":
+        elif self.video_model.__class__.__name__ == "VanillaVAEVideoPredictionModel":
             fetch_list = fetch_list + ["latent_loss", "recon_loss", "total_loss"]
             self.saver_loss = fetch_list[-2]
             self.saver_loss_name = "Reconstruction loss"
-        if self.video_model.__class__.__name__ == "VanillaGANVideoPredictionModel":
+        elif self.video_model.__class__.__name__ == "VanillaGANVideoPredictionModel":
             fetch_list = fetch_list + ["inputs", "total_loss"]
             self.saver_loss = fetch_list[-1]
             self.saver_loss_name = "Total loss"
-        if self.video_model.__class__.__name__ == "ConvLstmGANVideoPredictionModel":
+        elif self.video_model.__class__.__name__ == "ConvLstmGANVideoPredictionModel":
             fetch_list = fetch_list + ["inputs", "total_loss"]
             self.saver_loss = fetch_list[-1]
             self.saver_loss_name = "Total loss"
-        if self.video_model.__class__.__name__ == "WeatherBenchModel":
+        elif self.video_model.__class__.__name__ == "WeatherBenchModel":
             fetch_list = fetch_list + ["total_loss"]
             self.saver_loss = fetch_list[-1]
             self.saver_loss_name = "Total loss"
diff --git a/video_prediction_tools/model_modules/model_architectures.py b/video_prediction_tools/model_modules/model_architectures.py
index 9573812280cd3b8f808f55556bec7c546c745871..904862b150475c4e2b95b877ac451eed26f455dc 100644
--- a/video_prediction_tools/model_modules/model_architectures.py
+++ b/video_prediction_tools/model_modules/model_architectures.py
@@ -7,6 +7,7 @@ def known_models():
     model_mappings = {
         'savp': 'SAVPVideoPredictionModel',
         'convLSTM': 'VanillaConvLstmVideoPredictionModel',
+        'weatherBench': 'WeatherBenchModel'
     }
 
     return model_mappings
diff --git a/video_prediction_tools/model_modules/video_prediction/datasets/__init__.py b/video_prediction_tools/model_modules/video_prediction/datasets/__init__.py
index cd0ec2b230169016cc10aee5ee2ff3d7e4fc611b..6102fcbe50e53ffa3bb5411f77b068cdc9e67758 100644
--- a/video_prediction_tools/model_modules/video_prediction/datasets/__init__.py
+++ b/video_prediction_tools/model_modules/video_prediction/datasets/__init__.py
@@ -1,15 +1,5 @@
-from .base_dataset import BaseVideoDataset
-from .base_dataset import VideoDataset, SequenceExampleVideoDataset, VarLenFeatureVideoDataset
-from .google_robot_dataset import GoogleRobotVideoDataset
-from .sv2p_dataset import SV2PVideoDataset
-from .softmotion_dataset import SoftmotionVideoDataset
-from .kth_dataset import KTHVideoDataset
-from .ucf101_dataset import UCF101VideoDataset
-from .cartgripper_dataset import CartgripperVideoDataset
 from .era5_dataset import ERA5Dataset
-from .moving_mnist import MovingMnist
 from data_preprocess.dataset_options import known_datasets
-#from .era5_dataset_v2_anomaly import ERA5Dataset_v2_anomaly
 
 def get_dataset_class(dataset):
     dataset_mappings = known_datasets()
@@ -18,12 +8,6 @@ def get_dataset_class(dataset):
     if dataset_class is None:
         raise ValueError('Invalid dataset %s' % dataset)
     else:
-        # ERA5Dataset  movning_mnist does not inherit anything from VarLenFeatureVideoDataset-class, so it is the only dataset which does not need to be a subclass of BaseVideoDataset
-        #if not dataset_class == "ERA5Dataset" or not dataset_class == "MovingMnist":
-        #    dataset_class = globals().get(dataset_class)
-        #    if not issubclass(dataset_class,BaseVideoDataset):
-        #        raise ValueError('Dataset {0} is not a valid dataset'.format(dataset_class))
-        #else:
         dataset_class = globals().get(dataset_class)
 
     return dataset_class
diff --git a/video_prediction_tools/model_modules/video_prediction/models/__init__.py b/video_prediction_tools/model_modules/video_prediction/models/__init__.py
index 290def9f7c934f871fedd8c1703bbc114c822dcd..0f6e833c6376e33b33810844c8428e80067646bb 100644
--- a/video_prediction_tools/model_modules/video_prediction/models/__init__.py
+++ b/video_prediction_tools/model_modules/video_prediction/models/__init__.py
@@ -4,16 +4,10 @@
 
 from .base_model import BaseVideoPredictionModel
 from .base_model import VideoPredictionModel
-from .non_trainable_model import NonTrainableVideoPredictionModel
-from .non_trainable_model import GroundTruthVideoPredictionModel
-from .non_trainable_model import RepeatVideoPredictionModel
 from .savp_model import SAVPVideoPredictionModel
-from .vanilla_vae_model import VanillaVAEVideoPredictionModel
 from .vanilla_convLSTM_model import VanillaConvLstmVideoPredictionModel
-from .mcnet_model import McNetVideoPredictionModel
 from .test_model import TestModelVideoPredictionModel
 from model_modules.model_architectures import known_models
-from .convLSTM_GAN_model import ConvLstmGANVideoPredictionModel
 from .weatherBench3DCNN import WeatherBenchModel
 
 def get_model_class(model):