Skip to content
Snippets Groups Projects
Commit 6c7d45fd authored by Yan Ji's avatar Yan Ji
Browse files

debug for main_train.py

parent 9bb49049
Branches
Tags
No related merge requests found
......@@ -9,9 +9,7 @@
# The value of the 'pl'-key denotes the pressure level (in Pa) onto which the data is interpolated
# !!! This file should be only adapted if you are familiar with the ERA5 grib files!!!
{
"surface":{
["2t", "tcc","msl","10u","10v"]
},
"surface": ["2t", "tcc","msl","10u","10v"],
"multi":{
"t" : {
......
......@@ -404,29 +404,29 @@ class TrainModel(object):
fetch_list = fetch_list + ["L_p", "L_gdl", "L_GAN"]
self.saver_loss = fetch_list[-3] # ML: Is this a reasonable choice?
self.saver_loss_name = "Loss"
if self.video_model.__class__.__name__ == "VanillaConvLstmVideoPredictionModel":
elif self.video_model.__class__.__name__ == "VanillaConvLstmVideoPredictionModel":
fetch_list = fetch_list + ["inputs", "total_loss"]
self.saver_loss = fetch_list[-1]
self.saver_loss_name = "Total loss"
if self.video_model.__class__.__name__ == "SAVPVideoPredictionModel":
elif self.video_model.__class__.__name__ == "SAVPVideoPredictionModel":
fetch_list = fetch_list + ["g_losses", "d_losses", "d_loss", "g_loss", ("g_losses", "gen_l1_loss")]
# Add loss that is tracked
self.saver_loss = fetch_list[-1][1]
self.saver_loss_dict = fetch_list[-1][0]
self.saver_loss_name = "Generator L1 loss"
if self.video_model.__class__.__name__ == "VanillaVAEVideoPredictionModel":
elif self.video_model.__class__.__name__ == "VanillaVAEVideoPredictionModel":
fetch_list = fetch_list + ["latent_loss", "recon_loss", "total_loss"]
self.saver_loss = fetch_list[-2]
self.saver_loss_name = "Reconstruction loss"
if self.video_model.__class__.__name__ == "VanillaGANVideoPredictionModel":
elif self.video_model.__class__.__name__ == "VanillaGANVideoPredictionModel":
fetch_list = fetch_list + ["inputs", "total_loss"]
self.saver_loss = fetch_list[-1]
self.saver_loss_name = "Total loss"
if self.video_model.__class__.__name__ == "ConvLstmGANVideoPredictionModel":
elif self.video_model.__class__.__name__ == "ConvLstmGANVideoPredictionModel":
fetch_list = fetch_list + ["inputs", "total_loss"]
self.saver_loss = fetch_list[-1]
self.saver_loss_name = "Total loss"
if self.video_model.__class__.__name__ == "WeatherBenchModel":
elif self.video_model.__class__.__name__ == "WeatherBenchModel":
fetch_list = fetch_list + ["total_loss"]
self.saver_loss = fetch_list[-1]
self.saver_loss_name = "Total loss"
......
......@@ -7,6 +7,7 @@ def known_models():
model_mappings = {
'savp': 'SAVPVideoPredictionModel',
'convLSTM': 'VanillaConvLstmVideoPredictionModel',
'weatherBench': 'WeatherBenchModel'
}
return model_mappings
from .base_dataset import BaseVideoDataset
from .base_dataset import VideoDataset, SequenceExampleVideoDataset, VarLenFeatureVideoDataset
from .google_robot_dataset import GoogleRobotVideoDataset
from .sv2p_dataset import SV2PVideoDataset
from .softmotion_dataset import SoftmotionVideoDataset
from .kth_dataset import KTHVideoDataset
from .ucf101_dataset import UCF101VideoDataset
from .cartgripper_dataset import CartgripperVideoDataset
from .era5_dataset import ERA5Dataset
from .moving_mnist import MovingMnist
from data_preprocess.dataset_options import known_datasets
#from .era5_dataset_v2_anomaly import ERA5Dataset_v2_anomaly
def get_dataset_class(dataset):
dataset_mappings = known_datasets()
......@@ -18,12 +8,6 @@ def get_dataset_class(dataset):
if dataset_class is None:
raise ValueError('Invalid dataset %s' % dataset)
else:
# ERA5Dataset movning_mnist does not inherit anything from VarLenFeatureVideoDataset-class, so it is the only dataset which does not need to be a subclass of BaseVideoDataset
#if not dataset_class == "ERA5Dataset" or not dataset_class == "MovingMnist":
# dataset_class = globals().get(dataset_class)
# if not issubclass(dataset_class,BaseVideoDataset):
# raise ValueError('Dataset {0} is not a valid dataset'.format(dataset_class))
#else:
dataset_class = globals().get(dataset_class)
return dataset_class
......@@ -4,16 +4,10 @@
from .base_model import BaseVideoPredictionModel
from .base_model import VideoPredictionModel
from .non_trainable_model import NonTrainableVideoPredictionModel
from .non_trainable_model import GroundTruthVideoPredictionModel
from .non_trainable_model import RepeatVideoPredictionModel
from .savp_model import SAVPVideoPredictionModel
from .vanilla_vae_model import VanillaVAEVideoPredictionModel
from .vanilla_convLSTM_model import VanillaConvLstmVideoPredictionModel
from .mcnet_model import McNetVideoPredictionModel
from .test_model import TestModelVideoPredictionModel
from model_modules.model_architectures import known_models
from .convLSTM_GAN_model import ConvLstmGANVideoPredictionModel
from .weatherBench3DCNN import WeatherBenchModel
def get_model_class(model):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment