diff --git a/video_prediction_tools/data_preprocess/era5_varmapping_template.json b/video_prediction_tools/data_preprocess/era5_varmapping_template.json index cb83066c939ef9564e5d11960f3eb81376395b65..d4aa1d16a18c50b37b42a3a8cd89c8da80214155 100644 --- a/video_prediction_tools/data_preprocess/era5_varmapping_template.json +++ b/video_prediction_tools/data_preprocess/era5_varmapping_template.json @@ -9,9 +9,7 @@ # The value of the 'pl'-key denotes the pressure level (in Pa) onto which the data is interpolated # !!! This file should be only adapted if you are familiar with the ERA5 grib files!!! { -"surface":{ - ["2t", "tcc","msl","10u","10v"] - }, +"surface": ["2t", "tcc","msl","10u","10v"], "multi":{ "t" : { diff --git a/video_prediction_tools/main_scripts/main_train_models.py b/video_prediction_tools/main_scripts/main_train_models.py index 5242c8848bd5c0fa4ee8927532a2c8e7bf15743d..ddfc6d43cd24591e7a24f16395e9e6be1ee4fecc 100644 --- a/video_prediction_tools/main_scripts/main_train_models.py +++ b/video_prediction_tools/main_scripts/main_train_models.py @@ -404,29 +404,29 @@ class TrainModel(object): fetch_list = fetch_list + ["L_p", "L_gdl", "L_GAN"] self.saver_loss = fetch_list[-3] # ML: Is this a reasonable choice? self.saver_loss_name = "Loss" - if self.video_model.__class__.__name__ == "VanillaConvLstmVideoPredictionModel": + elif self.video_model.__class__.__name__ == "VanillaConvLstmVideoPredictionModel": fetch_list = fetch_list + ["inputs", "total_loss"] self.saver_loss = fetch_list[-1] self.saver_loss_name = "Total loss" - if self.video_model.__class__.__name__ == "SAVPVideoPredictionModel": + elif self.video_model.__class__.__name__ == "SAVPVideoPredictionModel": fetch_list = fetch_list + ["g_losses", "d_losses", "d_loss", "g_loss", ("g_losses", "gen_l1_loss")] # Add loss that is tracked self.saver_loss = fetch_list[-1][1] self.saver_loss_dict = fetch_list[-1][0] self.saver_loss_name = "Generator L1 loss" - if self.video_model.__class__.__name__ == "VanillaVAEVideoPredictionModel": + elif self.video_model.__class__.__name__ == "VanillaVAEVideoPredictionModel": fetch_list = fetch_list + ["latent_loss", "recon_loss", "total_loss"] self.saver_loss = fetch_list[-2] self.saver_loss_name = "Reconstruction loss" - if self.video_model.__class__.__name__ == "VanillaGANVideoPredictionModel": + elif self.video_model.__class__.__name__ == "VanillaGANVideoPredictionModel": fetch_list = fetch_list + ["inputs", "total_loss"] self.saver_loss = fetch_list[-1] self.saver_loss_name = "Total loss" - if self.video_model.__class__.__name__ == "ConvLstmGANVideoPredictionModel": + elif self.video_model.__class__.__name__ == "ConvLstmGANVideoPredictionModel": fetch_list = fetch_list + ["inputs", "total_loss"] self.saver_loss = fetch_list[-1] self.saver_loss_name = "Total loss" - if self.video_model.__class__.__name__ == "WeatherBenchModel": + elif self.video_model.__class__.__name__ == "WeatherBenchModel": fetch_list = fetch_list + ["total_loss"] self.saver_loss = fetch_list[-1] self.saver_loss_name = "Total loss" diff --git a/video_prediction_tools/model_modules/model_architectures.py b/video_prediction_tools/model_modules/model_architectures.py index 9573812280cd3b8f808f55556bec7c546c745871..904862b150475c4e2b95b877ac451eed26f455dc 100644 --- a/video_prediction_tools/model_modules/model_architectures.py +++ b/video_prediction_tools/model_modules/model_architectures.py @@ -7,6 +7,7 @@ def known_models(): model_mappings = { 'savp': 'SAVPVideoPredictionModel', 'convLSTM': 'VanillaConvLstmVideoPredictionModel', + 'weatherBench': 'WeatherBenchModel' } return model_mappings diff --git a/video_prediction_tools/model_modules/video_prediction/datasets/__init__.py b/video_prediction_tools/model_modules/video_prediction/datasets/__init__.py index cd0ec2b230169016cc10aee5ee2ff3d7e4fc611b..6102fcbe50e53ffa3bb5411f77b068cdc9e67758 100644 --- a/video_prediction_tools/model_modules/video_prediction/datasets/__init__.py +++ b/video_prediction_tools/model_modules/video_prediction/datasets/__init__.py @@ -1,15 +1,5 @@ -from .base_dataset import BaseVideoDataset -from .base_dataset import VideoDataset, SequenceExampleVideoDataset, VarLenFeatureVideoDataset -from .google_robot_dataset import GoogleRobotVideoDataset -from .sv2p_dataset import SV2PVideoDataset -from .softmotion_dataset import SoftmotionVideoDataset -from .kth_dataset import KTHVideoDataset -from .ucf101_dataset import UCF101VideoDataset -from .cartgripper_dataset import CartgripperVideoDataset from .era5_dataset import ERA5Dataset -from .moving_mnist import MovingMnist from data_preprocess.dataset_options import known_datasets -#from .era5_dataset_v2_anomaly import ERA5Dataset_v2_anomaly def get_dataset_class(dataset): dataset_mappings = known_datasets() @@ -18,12 +8,6 @@ def get_dataset_class(dataset): if dataset_class is None: raise ValueError('Invalid dataset %s' % dataset) else: - # ERA5Dataset movning_mnist does not inherit anything from VarLenFeatureVideoDataset-class, so it is the only dataset which does not need to be a subclass of BaseVideoDataset - #if not dataset_class == "ERA5Dataset" or not dataset_class == "MovingMnist": - # dataset_class = globals().get(dataset_class) - # if not issubclass(dataset_class,BaseVideoDataset): - # raise ValueError('Dataset {0} is not a valid dataset'.format(dataset_class)) - #else: dataset_class = globals().get(dataset_class) return dataset_class diff --git a/video_prediction_tools/model_modules/video_prediction/models/__init__.py b/video_prediction_tools/model_modules/video_prediction/models/__init__.py index 290def9f7c934f871fedd8c1703bbc114c822dcd..0f6e833c6376e33b33810844c8428e80067646bb 100644 --- a/video_prediction_tools/model_modules/video_prediction/models/__init__.py +++ b/video_prediction_tools/model_modules/video_prediction/models/__init__.py @@ -4,16 +4,10 @@ from .base_model import BaseVideoPredictionModel from .base_model import VideoPredictionModel -from .non_trainable_model import NonTrainableVideoPredictionModel -from .non_trainable_model import GroundTruthVideoPredictionModel -from .non_trainable_model import RepeatVideoPredictionModel from .savp_model import SAVPVideoPredictionModel -from .vanilla_vae_model import VanillaVAEVideoPredictionModel from .vanilla_convLSTM_model import VanillaConvLstmVideoPredictionModel -from .mcnet_model import McNetVideoPredictionModel from .test_model import TestModelVideoPredictionModel from model_modules.model_architectures import known_models -from .convLSTM_GAN_model import ConvLstmGANVideoPredictionModel from .weatherBench3DCNN import WeatherBenchModel def get_model_class(model):