From 6c7d45fda5f5fcd6a9079b00b25d1e0ea0b8bac9 Mon Sep 17 00:00:00 2001 From: Yan Ji <y.ji@fz-juelich.de> Date: Mon, 18 Jul 2022 15:23:38 +0200 Subject: [PATCH] debug for main_train.py --- .../era5_varmapping_template.json | 4 +--- .../main_scripts/main_train_models.py | 12 ++++++------ .../model_modules/model_architectures.py | 1 + .../video_prediction/datasets/__init__.py | 16 ---------------- .../video_prediction/models/__init__.py | 6 ------ 5 files changed, 8 insertions(+), 31 deletions(-) diff --git a/video_prediction_tools/data_preprocess/era5_varmapping_template.json b/video_prediction_tools/data_preprocess/era5_varmapping_template.json index cb83066c..d4aa1d16 100644 --- a/video_prediction_tools/data_preprocess/era5_varmapping_template.json +++ b/video_prediction_tools/data_preprocess/era5_varmapping_template.json @@ -9,9 +9,7 @@ # The value of the 'pl'-key denotes the pressure level (in Pa) onto which the data is interpolated # !!! This file should be only adapted if you are familiar with the ERA5 grib files!!! { -"surface":{ - ["2t", "tcc","msl","10u","10v"] - }, +"surface": ["2t", "tcc","msl","10u","10v"], "multi":{ "t" : { diff --git a/video_prediction_tools/main_scripts/main_train_models.py b/video_prediction_tools/main_scripts/main_train_models.py index 5242c884..ddfc6d43 100644 --- a/video_prediction_tools/main_scripts/main_train_models.py +++ b/video_prediction_tools/main_scripts/main_train_models.py @@ -404,29 +404,29 @@ class TrainModel(object): fetch_list = fetch_list + ["L_p", "L_gdl", "L_GAN"] self.saver_loss = fetch_list[-3] # ML: Is this a reasonable choice? self.saver_loss_name = "Loss" - if self.video_model.__class__.__name__ == "VanillaConvLstmVideoPredictionModel": + elif self.video_model.__class__.__name__ == "VanillaConvLstmVideoPredictionModel": fetch_list = fetch_list + ["inputs", "total_loss"] self.saver_loss = fetch_list[-1] self.saver_loss_name = "Total loss" - if self.video_model.__class__.__name__ == "SAVPVideoPredictionModel": + elif self.video_model.__class__.__name__ == "SAVPVideoPredictionModel": fetch_list = fetch_list + ["g_losses", "d_losses", "d_loss", "g_loss", ("g_losses", "gen_l1_loss")] # Add loss that is tracked self.saver_loss = fetch_list[-1][1] self.saver_loss_dict = fetch_list[-1][0] self.saver_loss_name = "Generator L1 loss" - if self.video_model.__class__.__name__ == "VanillaVAEVideoPredictionModel": + elif self.video_model.__class__.__name__ == "VanillaVAEVideoPredictionModel": fetch_list = fetch_list + ["latent_loss", "recon_loss", "total_loss"] self.saver_loss = fetch_list[-2] self.saver_loss_name = "Reconstruction loss" - if self.video_model.__class__.__name__ == "VanillaGANVideoPredictionModel": + elif self.video_model.__class__.__name__ == "VanillaGANVideoPredictionModel": fetch_list = fetch_list + ["inputs", "total_loss"] self.saver_loss = fetch_list[-1] self.saver_loss_name = "Total loss" - if self.video_model.__class__.__name__ == "ConvLstmGANVideoPredictionModel": + elif self.video_model.__class__.__name__ == "ConvLstmGANVideoPredictionModel": fetch_list = fetch_list + ["inputs", "total_loss"] self.saver_loss = fetch_list[-1] self.saver_loss_name = "Total loss" - if self.video_model.__class__.__name__ == "WeatherBenchModel": + elif self.video_model.__class__.__name__ == "WeatherBenchModel": fetch_list = fetch_list + ["total_loss"] self.saver_loss = fetch_list[-1] self.saver_loss_name = "Total loss" diff --git a/video_prediction_tools/model_modules/model_architectures.py b/video_prediction_tools/model_modules/model_architectures.py index 95738122..904862b1 100644 --- a/video_prediction_tools/model_modules/model_architectures.py +++ b/video_prediction_tools/model_modules/model_architectures.py @@ -7,6 +7,7 @@ def known_models(): model_mappings = { 'savp': 'SAVPVideoPredictionModel', 'convLSTM': 'VanillaConvLstmVideoPredictionModel', + 'weatherBench': 'WeatherBenchModel' } return model_mappings diff --git a/video_prediction_tools/model_modules/video_prediction/datasets/__init__.py b/video_prediction_tools/model_modules/video_prediction/datasets/__init__.py index cd0ec2b2..6102fcbe 100644 --- a/video_prediction_tools/model_modules/video_prediction/datasets/__init__.py +++ b/video_prediction_tools/model_modules/video_prediction/datasets/__init__.py @@ -1,15 +1,5 @@ -from .base_dataset import BaseVideoDataset -from .base_dataset import VideoDataset, SequenceExampleVideoDataset, VarLenFeatureVideoDataset -from .google_robot_dataset import GoogleRobotVideoDataset -from .sv2p_dataset import SV2PVideoDataset -from .softmotion_dataset import SoftmotionVideoDataset -from .kth_dataset import KTHVideoDataset -from .ucf101_dataset import UCF101VideoDataset -from .cartgripper_dataset import CartgripperVideoDataset from .era5_dataset import ERA5Dataset -from .moving_mnist import MovingMnist from data_preprocess.dataset_options import known_datasets -#from .era5_dataset_v2_anomaly import ERA5Dataset_v2_anomaly def get_dataset_class(dataset): dataset_mappings = known_datasets() @@ -18,12 +8,6 @@ def get_dataset_class(dataset): if dataset_class is None: raise ValueError('Invalid dataset %s' % dataset) else: - # ERA5Dataset movning_mnist does not inherit anything from VarLenFeatureVideoDataset-class, so it is the only dataset which does not need to be a subclass of BaseVideoDataset - #if not dataset_class == "ERA5Dataset" or not dataset_class == "MovingMnist": - # dataset_class = globals().get(dataset_class) - # if not issubclass(dataset_class,BaseVideoDataset): - # raise ValueError('Dataset {0} is not a valid dataset'.format(dataset_class)) - #else: dataset_class = globals().get(dataset_class) return dataset_class diff --git a/video_prediction_tools/model_modules/video_prediction/models/__init__.py b/video_prediction_tools/model_modules/video_prediction/models/__init__.py index 290def9f..0f6e833c 100644 --- a/video_prediction_tools/model_modules/video_prediction/models/__init__.py +++ b/video_prediction_tools/model_modules/video_prediction/models/__init__.py @@ -4,16 +4,10 @@ from .base_model import BaseVideoPredictionModel from .base_model import VideoPredictionModel -from .non_trainable_model import NonTrainableVideoPredictionModel -from .non_trainable_model import GroundTruthVideoPredictionModel -from .non_trainable_model import RepeatVideoPredictionModel from .savp_model import SAVPVideoPredictionModel -from .vanilla_vae_model import VanillaVAEVideoPredictionModel from .vanilla_convLSTM_model import VanillaConvLstmVideoPredictionModel -from .mcnet_model import McNetVideoPredictionModel from .test_model import TestModelVideoPredictionModel from model_modules.model_architectures import known_models -from .convLSTM_GAN_model import ConvLstmGANVideoPredictionModel from .weatherBench3DCNN import WeatherBenchModel def get_model_class(model): -- GitLab