diff --git a/video_prediction_tools/HPC_scripts/preprocess_data_era5_step1_template.sh b/video_prediction_tools/HPC_scripts/preprocess_data_era5_step1_template.sh index a6da4643636ab7997a72cfc81f9311de6d7e8527..80d4de5266bc57c944bd57ffa5359512b4f23a4b 100644 --- a/video_prediction_tools/HPC_scripts/preprocess_data_era5_step1_template.sh +++ b/video_prediction_tools/HPC_scripts/preprocess_data_era5_step1_template.sh @@ -23,9 +23,9 @@ VIRT_ENV_NAME="my_venv" # Activate virtual environment if needed (and possible) if [ -z ${VIRTUAL_ENV} ]; then - if [[ -f ../${VIRT_ENV_NAME}/bin/activate ]]; then + if [[ -f ../virtual_envs/${VIRT_ENV_NAME}/bin/activate ]]; then echo "Activating virtual environment..." - source ../${VIRT_ENV_NAME}/bin/activate + source ../virtual_envs/${VIRT_ENV_NAME}/bin/activate else echo "ERROR: Requested virtual environment ${VIRT_ENV_NAME} not found..." exit 1 diff --git a/video_prediction_tools/data_preprocess/preprocess_data_step2.py b/video_prediction_tools/data_preprocess/preprocess_data_step2.py index a062ee1cefc0b0683f59a4d86736a4500243761e..a197471b22f28cc1c3bae9fe29bd7279d2015cde 100644 --- a/video_prediction_tools/data_preprocess/preprocess_data_step2.py +++ b/video_prediction_tools/data_preprocess/preprocess_data_step2.py @@ -218,7 +218,8 @@ class ERA5Pkl2Tfrecords(ERA5Dataset): X_end = X_start + self.sequence_length seq = X_train[X_start:X_end, ...] # recording the start point of the timestamps (already datetime-objects) - t_start = ERA5Pkl2Tfrecords.ensure_datetime(T_train[X_start][0]) + + t_start = ERA5Pkl2Tfrecords.ensure_datetime(T_train[X_start]) seq = list(np.array(seq).reshape((self.sequence_length, self.height, self.width, self.nvars))) if not sequences: last_start_sequence_iter = sequence_iter diff --git a/video_prediction_tools/env_setup/install_venv_container.sh b/video_prediction_tools/env_setup/install_venv_container.sh index f0d53cdb495d5df6dffd30b79a19f53e1a0b2e98..3e5c35b9c4d635179fafab47ee65e153b80d2380 100755 --- a/video_prediction_tools/env_setup/install_venv_container.sh +++ b/video_prediction_tools/env_setup/install_venv_container.sh @@ -67,13 +67,15 @@ echo "Actiavting virtual environment ${VENV_NAME} to install required Python mod ACT_VENV="${VENV_DIR}/bin/activate" source "${VENV_DIR}/bin/activate" # set PYTHONPATH... -export PYTHONPATH="/usr/local/lib/python3.8/dist-packages/" +export PYTHONPATH=/usr/local/lib/python3.8/dist-packages/:$PYTHONPATH +export PYTHONPATH=${WORKING_DIR}/virtual_envs/${VENV_NAME}/lib/python3.8/site-packages:$PYTHONPATH export PYTHONPATH=${WORKING_DIR}:$PYTHONPATH export PYTHONPATH=${WORKING_DIR}/utils:$PYTHONPATH export PYTHONPATH=${WORKING_DIR}/model_modules:$PYTHONPATH export PYTHONPATH=${WORKING_DIR}/postprocess:$PYTHONPATH # ... also ensure that PYTHONPATH is appended when activating the virtual environment... -echo 'export PYTHONPATH="/usr/local/lib/python3.8/dist-packages/"' >> "${ACT_VENV}" +echo 'export PYTHONPATH=/usr/local/lib/python3.8/dist-packages/:$PYTHONPATH' >> "${ACT_VENV}" +echo 'export PYTHONPATH='${WORKING_DIR}'/virtual_envs/'${VENV_NAME}'/lib/python3.8/site-packages:$PYTHONPATH' >> ${ACT_VENV} echo 'export PYTHONPATH='${WORKING_DIR}':$PYTHONPATH' >> ${ACT_VENV} echo 'export PYTHONPATH='${WORKING_DIR}'/utils:$PYTHONPATH' >> ${ACT_VENV} echo 'export PYTHONPATH='${WORKING_DIR}'/model_modules:$PYTHONPATH' >> ${ACT_VENV} diff --git a/video_prediction_tools/env_setup/modules_preprocess+extract.sh b/video_prediction_tools/env_setup/modules_preprocess+extract.sh index c867554716e49f9fbe5c66275a158fefd505f927..7976201ab97cdc14b9ab3418e86898defc48fdf7 100755 --- a/video_prediction_tools/env_setup/modules_preprocess+extract.sh +++ b/video_prediction_tools/env_setup/modules_preprocess+extract.sh @@ -1,33 +1,32 @@ #!/usr/bin/env bash -# __author__ = Bing Gong, Michael Langguth -# __date__ = '2020_06_26' +# __author__ = Michael Langguth +# __date__ = '2022_02_07' -# This script loads the required modules for ambs on Juwels and HDF-ML. -# Note that some other packages have to be installed into a venv (see create_env.sh and requirements.txt). +# This script loads the required modules for AMBS on JSC's HPY_systems (HDF-ML, Juwels Cluster and Juwels Booster). +# Further Python-packages may be installed in the virtual environment created by create_env.sh +# (see also requirements.txt). -HOST_NAME=`hostname` +HOST_NAME=$(hostname) echo "Start loading modules on ${HOST_NAME} required for preprocessing..." -echo "modules_preprocess.sh is subject to: " +echo "modules_preprocess+extract.sh is used for: " +echo "* data_extraction_era5.sh" echo "* preprocess_data_era5_step1.sh" +echo "* generate_runscript.py" module purge -module use $OTHERSTAGES -ml Stages/2019a -ml GCC/8.3.0 -ml ParaStationMPI/5.2.2-1 -ml mpi4py/3.0.1-Python-3.6.8 -# serialized version is not available on HFML -# see https://gitlab.version.fz-juelich.de/haf/Wiki/-/wikis/HDF-ML%20System -if [[ "${HOST_NAME}" == hdfml* ]]; then - ml h5py/2.9.0-serial-Python-3.6.8 -elif [[ "${HOST_NAME}" == juwels* ]]; then - ml h5py/2.9.0-Python-3.6.8 -fi -ml SciPy-Stack/2019a-Python-3.6.8 -ml scikit/2019a-Python-3.6.8 -ml netcdf4-python/1.5.0.1-Python-3.6.8 +module use "$OTHERSTAGES" +ml Stages/2020 +ml GCC/10.3.0 +ml GCCcore/.10.3.0 +ml ParaStationMPI/5.4.10-1 +ml mpi4py/3.0.3-Python-3.8.5 +ml h5py/2.10.0-Python-3.8.5 +ml netcdf4-python/1.5.4-Python-3.8.5 +ml SciPy-Stack/2021-Python-3.8.5 +ml scikit/2021-Python-3.8.5 +ml CDO/2.0.0rc3 # clean up if triggered via script argument if [[ $1 == purge ]]; then diff --git a/video_prediction_tools/env_setup/requirements.txt b/video_prediction_tools/env_setup/requirements.txt index 9f433734a966541c6c6a20a6387a499716b2d80a..28b7c6f83865095745ccab685b08c60aba8a71f9 100755 --- a/video_prediction_tools/env_setup/requirements.txt +++ b/video_prediction_tools/env_setup/requirements.txt @@ -3,5 +3,7 @@ mpi4py==3.0.1 pandas==0.25.3 xarray==0.16.0 basemap==1.3.0 +numpy==1.17.3 # although this numpy-version is in the container, we set it here to avoid any further installation scikit-image==0.18.1 opencv-python-headless==4.2.0.34 +netcdf4 diff --git a/video_prediction_tools/env_setup/wrapper_container.sh b/video_prediction_tools/env_setup/wrapper_container.sh index fea29a0a9018a5436122389164cfff0859f22552..cfe716bee9f610b4a44988fc2ff6e4be048d06b4 100755 --- a/video_prediction_tools/env_setup/wrapper_container.sh +++ b/video_prediction_tools/env_setup/wrapper_container.sh @@ -27,6 +27,8 @@ export PYTHONPATH=/usr/local/lib/python3.8/dist-packages:$PYTHONPATH # ... and modules from this project export PYTHONPATH=${WORKING_DIR}:$PYTHONPATH export PYTHONPATH=${WORKING_DIR}/utils:$PYTHONPATH +export PYTHONPATH=${WORKING_DIR}/model_modules:$PYTHONPATH +export PYTHONPATH=${WORKING_DIR}/postprocess:$PYTHONPATH # Control echo "****** Check PYTHONPATH *****" diff --git a/video_prediction_tools/hparams/era5/savp/model_hparams_template.json b/video_prediction_tools/hparams/era5/savp/model_hparams_template.json index 2275e60f543badb1367351a50938e7bcacf2f119..f36e1c0b44279ad2e4f9e741c7bfade0a5aa0a05 100644 --- a/video_prediction_tools/hparams/era5/savp/model_hparams_template.json +++ b/video_prediction_tools/hparams/era5/savp/model_hparams_template.json @@ -1,5 +1,5 @@ { - "batch_size": 4, + "batch_size": 32, "lr": 0.0002, "beta1": 0.5, "beta2": 0.999, @@ -12,9 +12,11 @@ "gan_feature_cdist_weight": 0.0, "state_weight": 0.0, "nz": 16, - "max_epochs":2, + "max_epochs":4, "context_frames": 12, - "opt_var": "0" + "opt_var": "0", + "decay_steps":[3000,9000], + "end_lr": 0.00000008 } diff --git a/video_prediction_tools/main_scripts/main_train_models.py b/video_prediction_tools/main_scripts/main_train_models.py index 7ccddc88c66128bcab07104a818a9ff73faa3316..9e58de96a31913eb19678e151fac5c46d6e80409 100644 --- a/video_prediction_tools/main_scripts/main_train_models.py +++ b/video_prediction_tools/main_scripts/main_train_models.py @@ -11,14 +11,15 @@ __email__ = "b.gong@fz-juelich.de" __author__ = "Bing Gong, Michael Langguth" __date__ = "2020-10-22" +import os, glob import argparse import errno import json -import os from typing import Union, List import random import time import numpy as np +import xarray as xr import tensorflow as tf from model_modules.video_prediction import datasets, models import matplotlib.pyplot as plt @@ -26,12 +27,13 @@ import pickle as pkl from model_modules.video_prediction.utils import tf_utils from general_utils import * import math - +import shutil class TrainModel(object): def __init__(self, input_dir: str = None, output_dir: str = None, datasplit_dict: str = None, model_hparams_dict: str = None, model: str = None, checkpoint: str = None, dataset: str = None, - gpu_mem_frac: float = 1., seed: int = None, args=None, diag_intv_frac: float = 0.01, frac_save_model_start: float=None, prob_save_model:float=None): + gpu_mem_frac: float = 1., seed: int = None, args=None, diag_intv_frac: float = 0.001, + frac_start_save: float = None, frac_intv_save: float = None): """ Class instance for training the models :param input_dir: parent directory under which "pickle" and "tfrecords" files directiory are located @@ -44,12 +46,9 @@ class TrainModel(object): :param gpu_mem_frac: fraction of GPU memory to be preallocated :param seed: seed of the randomizers :param args: list of arguments passed - :param diag_intv_frac: interval for diagnozing and saving model; the fraction with respect to the number of - steps per epoch is denoted here, e.g. 0.01 with 1000 iteration steps per epoch results - into a diagnozing intreval of 10 iteration steps (= interval over which validation loss - is averaged to identify best model performance) - :param frac_save_model_start: fraction of total iterations steps as the start point to save checkpoints - :param prob_save_model: probabability that model are saved to checkpoint (control the frequences of saving model0) + :param diag_intv_frac: interval for diagnozing the model (create loss-curves and save pickle-file with losses) + :param frac_start_save: fraction of total iterations steps to start checkpointing the model + :param frac_intv_save: fraction of total iterations steps for checkpointing the model """ self.input_dir = os.path.normpath(input_dir) self.output_dir = os.path.normpath(output_dir) @@ -62,8 +61,8 @@ class TrainModel(object): self.seed = seed self.args = args self.diag_intv_frac = diag_intv_frac - self.frac_save_model_start = frac_save_model_start - self.prob_save_model = prob_save_model + self.frac_start_save = frac_start_save + self.frac_intv_save = frac_intv_save # for diagnozing and saving the model during training self.saver_loss = None # set in create_fetches_for_train-method self.saver_loss_name = None # set in create_fetches_for_train-method @@ -74,18 +73,16 @@ class TrainModel(object): self.set_seed() self.get_model_hparams_dict() self.load_params_from_checkpoints_dir() - self.setup_dataset() - self.setup_model() + self.setup_datasets() self.make_dataset_iterator() + self.setup_model() self.setup_graph() self.save_dataset_model_params_to_checkpoint_dir(dataset=self.train_dataset,video_model=self.video_model) self.count_parameters() self.create_saver_and_writer() self.setup_gpu_config() - self.calculate_samples_and_epochs() self.calculate_checkpoint_saver_conf() - def set_seed(self): """ Set seed to control the same train/val/testing dataset for the same seed @@ -148,16 +145,23 @@ class TrainModel(object): except FileNotFoundError: print("%{0}: model_hparams.json does not exist in {1}".format(method, self.checkpoint_dir)) - def setup_dataset(self): + def setup_datasets(self): """ Setup train and val dataset instance with the corresponding data split configuration. Simultaneously, sequence_length is attached to the hyperparameter dictionary. """ + # get some parameters from the model hyperparameters + self.batch_size = self.model_hparams_dict_load["batch_size"] + self.max_epochs = self.model_hparams_dict_load["max_epochs"] + # create dataset instance VideoDataset = datasets.get_dataset_class(self.dataset) self.train_dataset = VideoDataset(input_dir=self.input_dir, mode='train', datasplit_config=self.datasplit_dict, hparams_dict_config=self.model_hparams_dict) + self.calculate_samples_and_epochs() + self.model_hparams_dict_load.update({"sequence_length": self.train_dataset.sequence_length}) + # set-up validation dataset and calculate number of batches for calculating validation loss self.val_dataset = VideoDataset(input_dir=self.input_dir, mode='val', datasplit_config=self.datasplit_dict, - hparams_dict_config=self.model_hparams_dict) + hparams_dict_config=self.model_hparams_dict, nsamples_ref=self.num_examples) # Retrieve sequence length from dataset self.model_hparams_dict_load.update({"sequence_length": self.train_dataset.sequence_length}) @@ -239,15 +243,17 @@ class TrainModel(object): """ method = TrainModel.calculate_samples_and_epochs.__name__ - batch_size = self.video_model.hparams.batch_size - max_epochs = self.video_model.hparams.max_epochs # the number of epochs self.num_examples = self.train_dataset.num_examples_per_epoch() - self.steps_per_epoch = int(self.num_examples/batch_size) - self.diag_intv_step = int(self.diag_intv_frac*self.steps_per_epoch) - self.total_steps = self.steps_per_epoch * max_epochs + self.steps_per_epoch = int(self.num_examples/self.batch_size) + self.total_steps = self.steps_per_epoch * self.max_epochs + self.diag_intv_step = int(self.diag_intv_frac*self.total_steps) + if self.diag_intv_step == 0: + self.diag_intv_step = 1 + else: + pass print("%{}: Batch size: {}; max_epochs: {}; num_samples per epoch: {}; steps_per_epoch: {}, total steps: {}" - .format(method, batch_size,max_epochs, self.num_examples,self.steps_per_epoch,self.total_steps)) - + .format(method, self.batch_size, self.max_epochs, self.num_examples, self.steps_per_epoch, + self.total_steps)) def calculate_checkpoint_saver_conf(self): """ @@ -256,17 +262,18 @@ class TrainModel(object): """ method = TrainModel.calculate_checkpoint_saver_conf.__name__ - if hasattr(self.total_steps, "attr_name"): - raise SyntaxError(" function 'calculate_sample_and_epochs' is required to call to calcualte the total_step before all function {}".format(method)) - if self.prob_save_model > 1 or self.prob_save_model<0 : - raise ValueError("pro_save_model should be less than 1 and larger than 0") - if self.frac_save_model_start > 1 or self.frac_save_model_start<0: - raise ValueError("frac_save_model_start should be less than 1 and larger than 0") - - self.start_checkpoint_step = int(math.ceil(self.total_steps * self.frac_save_model_start)) - self.saver_interval_step = int(math.ceil(self.total_steps * self.prob_save_model)) - print("The model will be saved starting from step {} with {} interval step ".format(str(self.start_checkpoint_step),self.saver_interval_step)) + if not hasattr(self, "total_steps"): + raise RuntimeError("%{0} self.total_steps is still unset. Run calculate_samples_and_epochs beforehand" + .format(method)) + if self.frac_intv_save > 1 or self.frac_intv_save<0 : + raise ValueError("%{0}: frac_intv_save must be less than 1 and larger than 0".format(method)) + if self.frac_start_save > 1 or self.frac_start_save < 0: + raise ValueError("%{0}: frac_start_save must be less than 1 and larger than 0".format(method)) + self.chp_start_step = int(math.ceil(self.total_steps * self.frac_start_save)) + self.chp_intv_step = int(math.ceil(self.total_steps * self.frac_intv_save)) + print("%{0}: Model will be saved after step {1:d} at each {2:d} interval step " + .format(method, self.chp_start_step,self.chp_intv_step)) def restore(self, sess, checkpoints, restore_to_checkpoint_mapping=None): """ @@ -313,20 +320,14 @@ class TrainModel(object): def create_checkpoints_folder(self, step:int=None): """ - Create a folder to store checkpoint at certain step - :param step: the step you want to save the checkpoint + Create a folder to store checkpoint at certain step. + :param step: the iteration step corresponding to the checkpoint return : dir path to save model """ - dir_name = "checkpoint_" + str(step) - full_dir_name = os.path.join(self.output_dir,dir_name) - if os.path.isfile(os.path.join(full_dir_name,"checkpoints")): - print("The checkpoint at step {} exists".format(step)) - else: - os.mkdir(full_dir_name) + full_dir_name = os.path.join(self.output_dir, "checkpoint_{0:d}".format(step)) + os.makedirs(full_dir_name, exist_ok=True) return full_dir_name - - def train_model(self): """ Start session and train the model by looping over all iteration steps @@ -335,7 +336,6 @@ class TrainModel(object): self.global_step = tf.train.get_or_create_global_step() with tf.Session(config=self.config) as sess: - print("parameter_count =", sess.run(self.parameter_count)) sess.run(tf.global_variables_initializer()) sess.run(tf.local_variables_initializer()) self.restore(sess, self.checkpoint) @@ -347,7 +347,6 @@ class TrainModel(object): # initialize auxiliary variables time_per_iteration = [] run_start_time = time.time() - # perform iteration for step in range(start_step, self.total_steps): timeit_start = time.time() @@ -368,22 +367,22 @@ class TrainModel(object): time_iter = time.time() - timeit_start time_per_iteration.append(time_iter) print("%{0}: time needed for this step {1:.3f}s".format(method, time_iter)) - - if step > self.start_checkpoint_step and (step % self.saver_interval_step == 0 or step == self.total_steps - 1): + if (step >= self.chp_start_step and (step-self.chp_start_step)%self.chp_intv_step == 0) or \ + step == self.total_steps - 1: #create a checkpoint folder for step full_dir_name = self.create_checkpoints_folder(step=step) self.saver.save(sess, os.path.join(full_dir_name, "model_"), global_step=step) # pickle file and plots are always created - TrainModel.save_results_to_pkl(train_losses, val_losses, self.output_dir) - TrainModel.plot_train(train_losses, val_losses, self.saver_loss_name, self.output_dir) + if step % self.diag_intv_step == 0 or step == self.total_steps - 1: + TrainModel.save_results_to_pkl(train_losses, val_losses, self.output_dir) + TrainModel.plot_train(train_losses, val_losses, self.saver_loss_name, self.output_dir) - # Final diagnostics - # track time (save to pickle-files) + # Final diagnostics: training track time and save to pickle-files) train_time = time.time() - run_start_time - results_dict = {"train_time": train_time, - "total_steps": self.total_steps} - TrainModel.save_results_to_dict(results_dict,self.output_dir) + results_dict = {"train_time": train_time, "total_steps": self.total_steps} + TrainModel.save_results_to_dict(results_dict, self.output_dir) + print("%{0}: Training loss decreased from {1:.6f} to {2:.6f}:" .format(method, np.mean(train_losses[0:10]), np.mean(train_losses[-self.diag_intv_step:]))) print("%{0}: Validation loss decreased from {1:.6f} to {2:.6f}:" @@ -441,7 +440,6 @@ class TrainModel(object): if not self.saver_loss: raise AttributeError("%{0}: saver_loss is still not set. create_fetches_for_train must be run in advance." .format(method)) - if self.saver_loss_dict: fetch_list = ["summary_op", (self.saver_loss_dict, self.saver_loss)] else: @@ -497,39 +495,13 @@ class TrainModel(object): print ("Total_loss:{}".format(results["total_loss"])) elif self.video_model.__class__.__name__ == "SAVPVideoPredictionModel": print("Total_loss/g_losses:{}; d_losses:{}; g_loss:{}; d_loss: {}, gen_l1_loss: {}" - .format(results["g_losses"],results["d_losses"],results["g_loss"],results["d_loss"],results["gen_l1_loss"])) + .format(results["g_losses"], results["d_losses"], results["g_loss"], results["d_loss"], + results["gen_l1_loss"])) elif self.video_model.__class__.__name__ == "VanillaVAEVideoPredictionModel": - print("Total_loss:{}; latent_losses:{}; reconst_loss:{}".format(results["total_loss"],results["latent_loss"],results["recon_loss"])) + print("Total_loss:{}; latent_losses:{}; reconst_loss:{}" + .format(results["total_loss"], results["latent_loss"], results["recon_loss"])) else: - print("%{0}: Printing results of the model {1} is not implemented yet".format(method, self.video_model.__class__.__name__)) - - @staticmethod - def set_model_saver_flag(losses: List, old_min_loss: float, niter_steps: int = 100): - """ - Sets flag to save the model given that a new minimum in the loss is readched - :param losses: list of losses over iteration steps - :param old_min_loss: previous loss - :param niter_steps: number of iteration steps over which the loss is averaged - :return flag: True if model should be saved - :return loss_avg: updated minimum loss - """ - method = TrainModel.set_model_saver_flag.__name__ - - save_flag = False - if len(losses) <= niter_steps*2: - loss_avg = old_min_loss - return save_flag, loss_avg - - loss_avg = np.mean(losses[-niter_steps:]) - # print diagnosis - print("%{0}: Current loss: {1:.4f}, old minimum: {2:.4f}, model will be saved: {3}" - .format(method, loss_avg, old_min_loss, loss_avg < old_min_loss)) - if loss_avg < old_min_loss: - save_flag = True - else: - loss_avg = old_min_loss - - return save_flag, loss_avg + print("%{0}: Printing results of model '{1}' is not implemented yet".format(method, self.video_model.__class__.__name__)) @staticmethod def plot_train(train_losses, val_losses, loss_name, output_dir): @@ -584,6 +556,162 @@ class TrainModel(object): pkl.dump(loss_per_iteration_val,f) + +class BestModelSelector(object): + """ + Class to select the best performing model from multiple checkpoints created during training + """ + + def __init__(self, model_dir: str, eval_metric: str, criterion: str = "min", channel: int = 0, seed: int = 42): + """ + Class to retrieve the best model checkpoint. The last one is also retained. + :param model_dir: path to directory where checkpoints are saved (the trained model output directory) + :param eval_metric: evaluation metric for model selection (must be implemented in Scores) + :param criterion: set to 'min' ('max') for negatively (positively) oriented metrics + :param channel: channel of data used for selection + :param seed: seed for the Postprocess-instance + """ + method = self.__class__.__name__ + # sanity check + if not os.path.isdir(model_dir): + raise NotADirectoryError("{0}: The passed directory '{1}' does not exist".format(method, model_dir)) + assert criterion in ["min", "max"], "%{0}: criterion must be either 'min' or 'max'.".format(method) + # set class attributes + self.seed = seed + self.channel = channel + self.metric = eval_metric + self.checkpoint_base_dir = model_dir + self.checkpoints_all = BestModelSelector.get_checkpoints_dirs(model_dir) + self.ncheckpoints = len(self.checkpoints_all) + # evaluate all checkpoints... + self.checkpoints_eval_all = self.run(self.metric) + # ... and finalize by choosing the best model and cleaning up + _ = self.finalize(criterion) + + def run(self, eval_metric): + """ + Runs eager postprocessing on all checkpoints with evaluation of chosen metric + :param eval_metric: the target evaluation metric + :return: Populated self.checkpoints_eval_all where the average of the metric over all forecast hours is listed + + """ + method = BestModelSelector.run.__name__ + from main_visualize_postprocess import Postprocess + metric_avg_all = [] + + for checkpoint in self.checkpoints_all: + print("Start to evalute checkpoint:", checkpoint) + results_dir_eager = os.path.join(checkpoint, "results_eager") + eager_eval = Postprocess(results_dir=results_dir_eager, checkpoint=checkpoint, data_mode="val", batch_size=32, + seed=self.seed, eval_metrics=[eval_metric], channel=self.channel, frac_data=0.33, + lquick=True) + eager_eval.run() + eager_eval.handle_eval_metrics() + + eval_metric_ds = eager_eval.eval_metrics_ds + + metric_avg_all.append(BestModelSelector.get_avg_var(eval_metric_ds, "avg")) + print("Checkpoint {} is evaluated".format(checkpoint)) + + return metric_avg_all + + def finalize(self, criterion): + """ + Choose the best performing model checkpoint and delete all checkpoints apart from the best and the final ones + :return: status if everything runs + """ + method = BestModelSelector.finalize.__name__ + + best_ind = self.get_best_checkpoint(criterion) + if best_ind == self.ncheckpoints -1: + print("%{0}: Last model checkpoint performs best ({1}: {2:.5f}) and is retained exclusively." + .format(method, self.metric, self.checkpoints_eval_all[-1])) + else: + print("%{0}: The last ({1}: {2:.5f}) and the best ({1}: {3:.5f}) model checkpoint are retained." + .format(method, self.metric, self.checkpoints_eval_all[-1], self.checkpoints_eval_all[best_ind])) + + stat = self.clean_checkpoints(best_ind) + return stat + + def get_best_checkpoint(self, criterion: str): + """ + Choose the best performing model checkpoint + :param criterion: "max" or "min" + :return: index of best checkpoint in terms of evaluation metric + """ + method = BestModelSelector.get_best_checkpoint.__name__ + + if not self.checkpoints_eval_all: + raise AttributeError("%{0}: checkpoints_eval_all is still empty. run-method must be executed beforehand." + .format(method)) + + if criterion == "min": + best_index = np.argmin(self.checkpoints_eval_all) + else: + best_index = np.argmax(self.checkpoints_eval_all) + + return best_index + + def clean_checkpoints(self, best_ind: int): + """ + Delete all checkpoints apart from the best and the final ones + :param best_ind: index of best performing checkpoint + :return: status + """ + method = BestModelSelector.clean_checkpoints.__name__ + + # list of checkpoints to keep (while ensuring uniqueness!) + checkpoints_keep = list({self.checkpoints_all[best_ind], self.checkpoints_all[-1]}) + print("%{0}: The following checkpoints are retained: \n * {1}".format(method, "\n* ".join(checkpoints_keep))) + # drop checkpoints of interest from removal-list + checkpoints_op = self.checkpoints_all.copy() + for keep in checkpoints_keep: + checkpoints_op.remove(keep) + + for dir_path in checkpoints_op: + shutil.rmtree(dir_path) + print("%{0}: The checkpoint directory {1} was removed.".format(method, dir_path)) + + return True + + @staticmethod + def get_checkpoints_dirs(model_dir): + """ + Function to obtain all checkpoint directories in a list. + :param model_dir: path to directory where checkpoints are saved (the trained model output directory) + :return: list of all checkpoint directories in model_dir + """ + method = BestModelSelector.get_checkpoints_dirs.__name__ + + checkpoints_all = glob.glob(os.path.join(model_dir, "checkpoint*/")) + ncheckpoints = len(checkpoints_all) + if ncheckpoints == 0: + raise FileExistsError("{0}: No checkpoint folders found under '{1}'".format(method, model_dir)) + else: + # glob.glob yiels unsorted directories, i.e. do the soring now + checkpoints_all = sorted(checkpoints_all, key=lambda x: int(x.split("_")[-1].replace("/",""))) + print("%{0}: {1:d} checkpoints directories has been found.".format(method, ncheckpoints)) + + return checkpoints_all + + @staticmethod + def get_avg_var(ds: xr.Dataset, varname_substr: str): + """ + Retrieves and averages variable from dataset + :param ds: the dataset + :param varname_substr: the name of the variable or a substring suifficient to retrieve the variable + :return: the averaged variable + """ + varnames = list(ds.variables) + var_in_file = [s for s in varnames if varname_substr in s] + try: + var_mean = ds[var_in_file[0]].mean().values + except Exception as err: + raise err + + return var_mean + + def main(): parser = argparse.ArgumentParser() parser.add_argument("--input_dir", type=str, required=True, @@ -595,18 +723,20 @@ def main(): parser.add_argument("--model", type=str, help="Model class name") parser.add_argument("--model_hparams_dict", type=str, help="JSON-file of model hyperparameters") parser.add_argument("--gpu_mem_frac", type=float, default=0.99, help="Fraction of gpu memory to use") - parser.add_argument("--frac_save_model_start", type=float,default=0.6,help="fraction of the start step for saving checkpoint") - parser.add_argument("--prob_save_model", type = float, default = 0.01, help = "probabability that model are saved to checkpoint (control the frequences of saving model") + parser.add_argument("--frac_start_save", type=float, default=1., + help="Fraction of all iteration steps after which checkpointing starts.") + parser.add_argument("--frac_intv_save", type=float, default=0.01, + help="Fraction of all iteration steps to define the saving interval.") parser.add_argument("--seed", default=1234, type=int) args = parser.parse_args() # start timing for the whole run - timeit_start_total_time = time.time() - #create a training instance + timeit_start = time.time() + # create a training instance train_case = TrainModel(input_dir=args.input_dir,output_dir=args.output_dir,datasplit_dict=args.datasplit_dict, model_hparams_dict=args.model_hparams_dict,model=args.model,checkpoint=args.checkpoint, dataset=args.dataset, - gpu_mem_frac=args.gpu_mem_frac, seed=args.seed, args=args, frac_save_model_start=args.frac_save_model_start, - prob_save_model=args.prob_save_model) + gpu_mem_frac=args.gpu_mem_frac, seed=args.seed, args=args, frac_start_save=args.frac_start_save, + frac_intv_save=args.frac_intv_save) print('----------------------------------- Options ------------------------------------') for k, v in args._get_kwargs(): @@ -618,9 +748,18 @@ def main(): # train model train_time, time_per_iteration = train_case.train_model() - - total_run_time = time.time() - timeit_start_total_time - train_case.save_timing_to_pkl(total_run_time, train_time, time_per_iteration, args.output_dir) - + timeit_after_train = time.time() + train_case.save_timing_to_pkl(timeit_after_train - timeit_start, train_time, time_per_iteration, args.output_dir) + + # select best model + if args.dataset == "era5" and args.frac_start_save < 1.: + _ = BestModelSelector(args.output_dir, "mse") + timeit_finish = time.time() + print("Selecting the best model checkpoint took {0:.2f} minutes.".format((timeit_finish - timeit_after_train)/60.)) + else: + timeit_finish = time.time() + print("Total time elapsed {0} minutes.".format((timeit_finish - timeit_start)/60.)) + + if __name__ == '__main__': main() diff --git a/video_prediction_tools/main_scripts/main_visualize_postprocess.py b/video_prediction_tools/main_scripts/main_visualize_postprocess.py index 65df7e4abc3991cf6ae6d81987a46608611fa911..f95cfd79d9439a3009a0ca60c29fe57559024b00 100644 --- a/video_prediction_tools/main_scripts/main_visualize_postprocess.py +++ b/video_prediction_tools/main_scripts/main_visualize_postprocess.py @@ -26,7 +26,7 @@ from normalization import Norm_data from netcdf_datahandling import get_era5_varatts from general_utils import check_dir from metadata import MetaData as MetaData -from main_scripts.main_train_models import * +from main_train_models import TrainModel from data_preprocess.preprocess_data_step2 import * from model_modules.video_prediction import datasets, models, metrics from statistical_evaluation import perform_block_bootstrap_metric, avg_metrics, calculate_cond_quantiles, Scores @@ -34,31 +34,32 @@ from postprocess_plotting import plot_avg_eval_metrics, plot_cond_quantile, crea class Postprocess(TrainModel): - def __init__(self, results_dir: str = None, checkpoint: str = None, mode: str = "test", batch_size: int = None, - num_stochastic_samples: int = 1, stochastic_plot_id: int = 0, gpu_mem_frac: float = None, - seed: int = None, channel: int = 0, args=None, run_mode: str = "deterministic", - eval_metrics: List = ("mse", "psnr", "ssim", "acc"), - clim_path: str = "/p/scratch/deepacf/video_prediction_shared_folder/preprocessedData/T2monthly", - lquick: bool = None): + def __init__(self, results_dir: str = None, checkpoint: str = None, data_mode: str = "test", batch_size: int = None, + gpu_mem_frac: float = None, num_stochastic_samples: int = 1, stochastic_plot_id: int = 0, + seed: int = None, channel: int = 0, run_mode: str = "deterministic", lquick: bool = None, + frac_data: float = 1., eval_metrics: List = ("mse", "psnr", "ssim", "acc"), args=None, + clim_path: str = "/p/scratch/deepacf/video_prediction_shared_folder/preprocessedData/T2monthly"): """ Initialization of the class instance for postprocessing (generation of forecasts from trained model + basic evauation). :param results_dir: output directory to save results :param checkpoint: directory point to the model checkpoints - :param mode: mode of dataset to be processed ("train", "val" or "test"), default: "test" + :param data_mode: mode of dataset to be processed ("train", "val" or "test"), default: "test" :param batch_size: mini-batch size for generating forecasts from trained model + :param gpu_mem_frac: fraction of GPU memory to be preallocated :param num_stochastic_samples: number of ensemble members for variational models (SAVP, VAE), default: 1 not supported yet!!! :param stochastic_plot_id: not supported yet! - :param gpu_mem_frac: fraction of GPU memory to be pre-allocated :param seed: Integer controlling randomization :param channel: Channel of interest for statistical evaluation - :param args: namespace of parsed arguments :param run_mode: "deterministic" or "stochastic", default: "deterministic", "stochastic is not supported yet!!! + :param lquick: flag for quick evaluation + :param frac_data: fraction of dataset to be used for evaluation (only applied when shuffling is active) :param eval_metrics: metrics used to evaluate the trained model :param clim_path: the path to the netCDF-file storing climatolgical data - :param lquick: flag for quick evaluation + :param args: namespace of parsed arguments """ + tf.reset_default_graph() # copy over attributes from parsed argument self.results_dir = self.output_dir = os.path.normpath(results_dir) _ = check_dir(self.results_dir, lcreate=True) @@ -75,9 +76,10 @@ class Postprocess(TrainModel): self.checkpoint += "/" # trick to handle checkpoint-directory and file simulataneously self.clim_path = clim_path self.run_mode = run_mode - self.mode = mode + self.data_mode = data_mode self.channel = channel self.lquick = lquick + self.frac_data = frac_data # Attributes set during runtime self.norm_cls = None # configuration of basic evaluation @@ -85,8 +87,9 @@ class Postprocess(TrainModel): self.nboots_block = 1000 self.block_length = 7 * 24 # this corresponds to a block length of 7 days in case of hourly forecasts # initialize evrything to get an executable Postprocess instance - self.save_args_to_option_json() # create options.json-in results directory - self.copy_data_model_json() # copy over JSON-files from model directory + if args is not None: + self.save_args_to_option_json() # create options.json in results directory + self.copy_data_model_json() # copy over JSON-files from model directory # get some parameters related to model and dataset self.datasplit_dict, self.model_hparams_dict, self.dataset, self.model, self.input_dir_tfr = self.load_jsons() self.model_hparams_dict_load = self.get_model_hparams_dict() @@ -104,18 +107,19 @@ class Postprocess(TrainModel): self.stat_fl = self.set_stat_file() self.cond_quantile_vars = self.init_cond_quantile_vars() # setup test dataset and model - self.test_dataset, self.num_samples_per_epoch = self.setup_test_dataset() + self.test_dataset, self.num_samples_per_epoch = self.setup_dataset() + if lquick and self.test_dataset.shuffled: + self.num_samples_per_epoch = Postprocess.reduce_samples(self.num_samples_per_epoch, frac_data) # self.num_samples_per_epoch = 100 # reduced number of epoch samples -> useful for testing self.sequence_length, self.context_frames, self.future_length = self.get_data_params() self.inputs, self.input_ts = self.make_test_dataset_iterator() + self.data_clim = None + if "acc" in eval_metrics: + self.load_climdata() # set-up model, its graph and do GPU-configuration (from TrainModel) - self.setup_model(mode=self.mode) + self.setup_model(mode="test") self.setup_graph() self.setup_gpu_config() - if "acc" in eval_metrics: - self.load_climdata() - else: - self.data_clim = None # Methods that are called during initialization def get_input_dirs(self): @@ -153,11 +157,11 @@ class Postprocess(TrainModel): method_name = Postprocess.copy_data_model_json.__name__ # correctness of self.checkpoint and self.results_dir is already checked in __init__ - checkpoint_dir = os.path.dirname(self.checkpoint) - model_opt_js = os.path.join(checkpoint_dir, "options.json") - model_ds_js = os.path.join(checkpoint_dir, "dataset_hparams.json") - model_hp_js = os.path.join(checkpoint_dir, "model_hparams.json") - model_dd_js = os.path.join(checkpoint_dir, "data_split.json") + model_outdir = os.path.split(os.path.dirname(self.checkpoint))[0] + model_opt_js = os.path.join(model_outdir, "options.json") + model_ds_js = os.path.join(model_outdir, "dataset_hparams.json") + model_hp_js = os.path.join(model_outdir, "model_hparams.json") + model_dd_js = os.path.join(model_outdir, "data_split.json") if os.path.isfile(model_opt_js): shutil.copy(model_opt_js, os.path.join(self.results_dir, "options_checkpoints.json")) @@ -241,14 +245,6 @@ class Postprocess(TrainModel): print("%{0}: Something went wrong when getting metadata from file '{1}'".format(method_name, metadata_fl)) raise err - # when the metadat is loaded without problems, the follwoing will work - self.height, self.width = md_instance.ny, md_instance.nx - self.vars_in = md_instance.variables - - self.lats = xr.DataArray(md_instance.lat, coords={"lat": md_instance.lat}, dims="lat", - attrs={"units": "degrees_east"}) - self.lons = xr.DataArray(md_instance.lon, coords={"lon": md_instance.lon}, dims="lon", - attrs={"units": "degrees_north"}) return md_instance def load_climdata(self,clim_path="/p/scratch/deepacf/video_prediction_shared_folder/preprocessedData/T2monthly", @@ -284,20 +280,22 @@ class Postprocess(TrainModel): coords_new["month"] = np.arange(1, 13) coords_new["hour"] = np.arange(0, 24) # initialize a new data array with explicit dimensions for month and hour - data_clim_new = xr.DataArray(np.full((12, 24, nlat, nlon), np.nan), coords=coords_new, dims=["month", "hour", "lat", "lon"]) + data_clim_new = xr.DataArray(np.full((12, 24, nlat, nlon), np.nan), coords=coords_new, + dims=["month", "hour", "lat", "lon"]) # do the reorganization for month in np.arange(1, 13): data_clim_new.loc[dict(month=month)]=dt_clim.sel(time=dt_clim["time.month"]==month) self.data_clim = data_clim_new[dict(lon=meta_lon_loc,lat=meta_lat_loc)] - def setup_test_dataset(self): + def setup_dataset(self): """ setup the test dataset instance :return test_dataset: the test dataset instance """ VideoDataset = datasets.get_dataset_class(self.dataset) - test_dataset = VideoDataset(input_dir=self.input_dir_tfr, mode=self.mode, datasplit_config=self.datasplit_dict) + test_dataset = VideoDataset(input_dir=self.input_dir_tfr, mode=self.data_mode, + datasplit_config=self.datasplit_dict) nsamples = test_dataset.num_examples_per_epoch() return test_dataset, nsamples @@ -391,14 +389,14 @@ class Postprocess(TrainModel): if not hasattr(self, "num_stochastic_samples"): raise AttributeError("%{0}: Attribute num_stochastic_samples is still unset".format(method)) - if self.model == "convLSTM" or self.model == "test_model" or self.model == 'mcnet': + if np.any(self.model in ["convLSTM", "test_model", "mcnet"]): if self.num_stochastic_samples > 1: print("Number of samples for deterministic model cannot be larger than 1. Higher values are ignored.") self.num_stochastic_samples = 1 # the run-factory def run(self): - if self.model == "convLSTM" or self.model == "test_model" or self.model == 'mcnet': + if np.any(self.model in ["convLSTM", "test_model", "mcnet"]): self.run_deterministic() elif self.run_mode == "deterministic": self.run_deterministic() @@ -530,7 +528,7 @@ class Postprocess(TrainModel): nsamples, self.future_length) cond_quantiple_ds = None - while sample_ind < self.num_samples_per_epoch: + while sample_ind < nsamples: # get normalized and denormalized input data input_results, input_images_denorm, t_starts = self.get_input_data_per_batch(self.inputs) # feed and run the trained model; returned array has the shape [batchsize, seq_len, lat, lon, channel] @@ -584,7 +582,8 @@ class Postprocess(TrainModel): # safe dataset with evaluation metrics for later use self.eval_metrics_ds = eval_metric_ds self.cond_quantiple_ds = cond_quantiple_ds - + self.sess.close() + # all methods of the run factory def init_session(self): """ @@ -672,7 +671,7 @@ class Postprocess(TrainModel): # dictionary of implemented evaluation metrics dims = ["lat", "lon"] - eval_metrics_func = [Scores(metric,dims).score_func for metric in self.eval_metrics] + eval_metrics_func = [Scores(metric, dims).score_func for metric in self.eval_metrics] varname_ref = "{0}_ref".format(varname) # reset init-time coordinate of metric_ds in place and get indices for slicing ind_end = np.minimum(ind_start + self.batch_size, self.num_samples_per_epoch) @@ -875,6 +874,24 @@ class Postprocess(TrainModel): plot_cond_quantile(quantile_panel_lbr, cond_variable_lbr, plt_fname_lbr) + @staticmethod + def reduce_samples(nsamples: int, frac_data: float): + """ + Reduce number of sample for Postprocessing + :param nsamples: original number of samples + :param frac_data: fraction of samples used for evaluation + :return: reduced number of samples + """ + method = Postprocess.reduce_samples.__name__ + + if frac_data <= 0. or frac_data >= 1.: + print("%{0}: frac_data is not within [0..1] and is therefore ignored.".format(method)) + return nsamples + else: + nsamples_new = int(np.ceil(nsamples*frac_data)) + print("%{0}: Sample size is reduced from {1:d} to {2:d}".format(method, int(nsamples), nsamples_new)) + return nsamples_new + @staticmethod def clean_obj_attribute(obj, attr_name, lremove=False): """ @@ -1028,7 +1045,11 @@ class Postprocess(TrainModel): var_pickle.extend(var_origin_pickle) # Retrieve starting index - ind = list(time_pickle).index(np.array(ts_persistence[0])) + try: + ind = list(time_pickle).index(np.array(ts_persistence[0])) + except Exception as err: + print("Please consider return Data preprocess step 1 to generate entire month data") + raise err var_persistence = np.array(var_pickle)[ind:ind + len(ts_persistence)] time_persistence = np.array(time_pickle)[ind:ind + len(ts_persistence)].ravel() @@ -1127,10 +1148,10 @@ class Postprocess(TrainModel): raise NotADirectoryError("%{0}: The directory to store the netCDf-file does not exist.".format(method)) encode_nc = {key: {"zlib": True, "complevel": comp_level} for key in ds.keys()} - + # populate data in netCDF-file (take care for the mode!) try: - ds.to_netcdf(nc_fname, encoding=encode_nc) + ds.to_netcdf(nc_fname, encoding=encode_nc,engine="netcdf4") print("%{0}: netCDF-file '{1}' was created successfully.".format(method, nc_fname)) except Exception as err: print("%{0}: Something unexpected happened when creating netCDF-file '1'".format(method, nc_fname)) @@ -1159,7 +1180,7 @@ class Postprocess(TrainModel): if dtype is None: dtype = np.double else: - if not np.issubdtype(dtype, np.dtype(float).type): + if not np.issubdtype(dtype, np.number): raise ValueError("%{0}: dytpe must be a NumPy datatype, but is '{1}'".format(method, np.dtype(dtype))) if ds_preexist is None: @@ -1232,7 +1253,7 @@ def main(): parser.add_argument("--gpu_mem_frac", type=float, default=0.95, help="fraction of gpu memory to use") parser.add_argument("--seed", type=int, default=7) parser.add_argument("--evaluation_metrics", "-eval_metrics", dest="eval_metrics", nargs="+", - default=("mse", "psnr", "ssim", "acc"), + default=("mse", "psnr", "ssim", "acc", "texture"), help="Metrics to be evaluate the trained model. Must be known metrics, see Scores-class.") parser.add_argument("--channel", "-channel", dest="channel", type=int, default=0, help="Channel which is used for evaluation.") @@ -1261,7 +1282,7 @@ def main(): "* checkpointed model: {0}\n * no conditional quantile and forecast example plots".format(chp)) # initialize postprocessing instance - postproc_instance = Postprocess(results_dir=results_dir, checkpoint=args.checkpoint, mode="test", + postproc_instance = Postprocess(results_dir=results_dir, checkpoint=args.checkpoint, data_mode="test", batch_size=args.batch_size, num_stochastic_samples=args.num_stochastic_samples, gpu_mem_frac=args.gpu_mem_frac, seed=args.seed, args=args, eval_metrics=eval_metrics, channel=args.channel, lquick=args.lquick) diff --git a/video_prediction_tools/model_modules/video_prediction/datasets/base_dataset.py b/video_prediction_tools/model_modules/video_prediction/datasets/base_dataset.py index ed9e15e184a6b2944fc5f2c35b5ea47132fb5a28..99d7ac163883cbea2da2ab2ad1da156ebc2b5ff1 100644 --- a/video_prediction_tools/model_modules/video_prediction/datasets/base_dataset.py +++ b/video_prediction_tools/model_modules/video_prediction/datasets/base_dataset.py @@ -13,34 +13,30 @@ from tensorflow.contrib.training import HParams class BaseVideoDataset(object): - def __init__(self, input_dir, mode='train', num_epochs=None, seed=None, + def __init__(self, input_dir: str, mode: str = "train", num_epochs: int = None, seed: int = None, hparams_dict=None, hparams=None): """ - Args: - input_dir: either a directory containing subdirectories train, - val, test, etc, or a directory containing the tfrecords. - mode: either train, val, or test - num_epochs: if None, dataset is iterated indefinitely. - seed: random seed for the op that samples subsequences. - hparams_dict: a dict of `name=value` pairs, where `name` must be - defined in `self.get_default_hparams()`. - hparams: a string of comma separated list of `name=value` pairs, - where `name` must be defined in `self.get_default_hparams()`. - These values overrides any values in hparams_dict (if any). - Note: - self.input_dir is the directory containing the tfrecords. + This class is used for preparing data for training/validation and test models. + :param input_dir: the path of tfrecords files + :param mode: "train","val" or "test" + :param num_epochs: number of epochs + :param seed: the seed for dataset + :param hparams_dict: a dict of `name=value` pairs, where `name` must be defined in `self.get_default_hparams()`. + :param hparams: a dict of `name=value` pairs where `name` must be defined in `self.get_default_hparams()`. + These values overrides any values in hparams_dict (if any). """ + method = self.__class__.__name__ self.input_dir = os.path.normpath(os.path.expanduser(input_dir)) self.mode = mode self.num_epochs = num_epochs self.seed = seed - + self.shuffled = False # will be set properly in make_dataset-method + # sanity checks if self.mode not in ('train', 'val', 'test'): - raise ValueError('Invalid mode %s' % self.mode) - + raise ValueError('%{0}: Invalid mode {1}'.format(method, self.mode)) if not os.path.exists(self.input_dir): - raise FileNotFoundError("input_dir %s does not exist" % self.input_dir) + raise FileNotFoundError("%{0} input_dir '{1}' does not exist".format(method, self.input_dir)) self.filenames = None # look for tfrecords in input_dir and input_dir/mode directories for input_dir in [self.input_dir, os.path.join(self.input_dir, self.mode)]: @@ -57,13 +53,6 @@ class BaseVideoDataset(object): self.action_like_names_and_shapes = OrderedDict() self.hparams = self.parse_hparams(hparams_dict, hparams) - #Bing: add this for anomaly -# if os.path.exists(input_dir+"_mean"): -# input_mean_dir = input_dir+"_mean" -# self.filenames_mean = sorted(glob.glob(os.path.join(input_mean_dir, '*.tfrecord*'))) -# else: -# self.filenames_mean = None - def get_default_hparams_dict(self): """ @@ -134,14 +123,13 @@ class BaseVideoDataset(object): Parses a single tf.train.Example or tf.train.SequenceExample into images, states, actions, etc tensors. """ - - raise NotImplementedError def make_dataset(self, batch_size): filenames = self.filenames shuffle = self.mode == 'train' or (self.mode == 'val' and self.hparams.shuffle_on_val) if shuffle: + self.shuffled = True random.shuffle(filenames) dataset = tf.data.TFRecordDataset(filenames, buffer_size= 8 * 1024 * 1024) #todo: what is buffer_size @@ -167,7 +155,6 @@ class BaseVideoDataset(object): iterator = dataset.make_one_shot_iterator() return iterator.get_next() - def decode_and_preprocess_images(self, image_buffers, image_shape): def decode_and_preprocess_image(image_buffer): print("image buffer", tf.shape(image_buffer)) @@ -258,7 +245,6 @@ class BaseVideoDataset(object): raise NotImplementedError - class VideoDataset(BaseVideoDataset): """ This class supports reading tfrecords where a sequence is stored as diff --git a/video_prediction_tools/model_modules/video_prediction/datasets/era5_dataset.py b/video_prediction_tools/model_modules/video_prediction/datasets/era5_dataset.py index 1a46a99dcd1b7918f42b609d96588b3d528fb000..eb69a74045ffad93502afbdb1aac8fa20b593294 100644 --- a/video_prediction_tools/model_modules/video_prediction/datasets/era5_dataset.py +++ b/video_prediction_tools/model_modules/video_prediction/datasets/era5_dataset.py @@ -5,11 +5,11 @@ __email__ = "b.gong@fz-juelich.de" __author__ = "Bing Gong, Scarlet Stadtler,Michael Langguth" -import argparse import os import glob import random import json +import numpy as np import tensorflow as tf from collections import OrderedDict from tensorflow.contrib.training import HParams @@ -18,26 +18,37 @@ from general_utils import reduce_dict class ERA5Dataset(object): - def __init__(self,input_dir=None,datasplit_config=None,hparams_dict_config=None, mode='train',seed=None): + def __init__(self, input_dir: str = None, datasplit_config: str = None, hparams_dict_config: str = None, + mode: str = "train", seed: int = None, nsamples_ref: int = None): """ This class is used for preparing data for training/validation and test models - args: - input_dir : the path of tfrecords files - datasplit_config : the path pointing to the datasplit_config json file - hparams_dict_config : the path to the dict that contains hparameters, - mode : string, "train","val" or "test" - seed : int, the seed for dataset - """ - # super(ERA5Dataset, self).__init__(**kwargs) + :param input_dir: the path of tfrecords files + :param datasplit_config: the path pointing to the datasplit_config json file + :param hparams_dict_config: the path to the dict that contains hparameters, + :param mode: string, "train","val" or "test" + :param seed: int, the seed for dataset + :param nsamples_ref: number of reference samples whch can be used to control repetition factor for dataset + for ensuring adopted size of dataset iterator (used for validation data during training) + Example: Let nsamples_ref be 1000 while the current datset consists 100 samples, then + the repetition-factor will be 10 (i.e. nsamples*rep_fac = nsamples_ref) + """ + method = self.__class__.__name__ + self.input_dir = input_dir self.datasplit_config = datasplit_config self.mode = mode self.seed = seed self.sequence_length = None # will be set in get_example_info + self.nsamples_ref = None + self.shuffled = False # will be set properly in make_dataset-method + # sanity checks if self.mode not in ('train', 'val', 'test'): - raise ValueError('Invalid mode %s' % self.mode) + raise ValueError('%{0}: Invalid mode {1}'.format(method, self.mode)) if not os.path.exists(self.input_dir): - raise FileNotFoundError("input_dir %s does not exist" % self.input_dir) + raise FileNotFoundError("%{0} input_dir '{1}' does not exist".format(method, self.input_dir)) + if nsamples_ref is not None: + self.nsamples_ref = nsamples_ref + # get configuration parameters from datasplit- and modelparameters-files self.datasplit_dict_path = datasplit_config self.data_dict = self.get_datasplit() self.hparams_dict_config = hparams_dict_config @@ -59,7 +70,6 @@ class ERA5Dataset(object): def get_default_hparams(self): return HParams(**self.get_default_hparams_dict()) - def get_default_hparams_dict(self): """ Provide dictionary containing default hyperparameters for the dataset @@ -72,9 +82,9 @@ class ERA5Dataset(object): """ hparams = dict( context_frames=10, - max_epochs = 20, - batch_size = 40, - shuffle_on_val= True, + max_epochs=20, + batch_size=40, + shuffle_on_val=True, ) return hparams @@ -84,8 +94,8 @@ class ERA5Dataset(object): """ with open(self.datasplit_dict_path) as f: - self.d = json.load(f) - return self.d + datasplit_dict = json.load(f) + return datasplit_dict def parse_hparams(self): """ @@ -96,7 +106,6 @@ class ERA5Dataset(object): return parsed_hparams - def get_tfrecords_filesnames_base_datasplit(self): """ Get absolute .tfrecord path names based on the data splits patterns @@ -116,7 +125,6 @@ class ERA5Dataset(object): if not self.filenames: raise FileNotFoundError('No tfrecords were found in %s' % self.input_dir) - def get_example_info(self): """ Get the data information from an example tfrecord file @@ -140,9 +148,9 @@ class ERA5Dataset(object): with open(num_seq_file, 'r') as dfile: num_seqs = dfile.readlines() num_sequences = [int(num_seq.strip()) for num_seq in num_seqs] - self.num_examples_per_epoch = len_fnames * num_sequences[0] - return self.num_examples_per_epoch + num_examples_per_epoch = len_fnames * num_sequences[0] + return num_examples_per_epoch def make_dataset(self, batch_size): """ @@ -153,7 +161,10 @@ class ERA5Dataset(object): args: batch_size: int, the size of samples fed into the models per iteration """ + method = ERA5Dataset.make_dataset.__name__ + self.num_epochs = self.hparams.max_epochs + def parser(serialized_example): seqs = OrderedDict() keys_to_features = { @@ -178,15 +189,20 @@ class ERA5Dataset(object): filenames = self.filenames shuffle = self.mode == 'train' or (self.mode == 'val' and self.hparams.shuffle_on_val) if shuffle: + self.shuffled = True random.shuffle(filenames) - dataset = tf.data.TFRecordDataset(filenames, buffer_size = 8* 1024 * 1024) - #dataset = dataset.filter(self.filter) + dataset = tf.data.TFRecordDataset(filenames, buffer_size=8*1024*1024) + + # set-up dataset iterator + nrepeat = self.num_epochs + if self.nsamples_ref: + num_samples = self.num_examples_per_epoch() + nrepeat = int(nrepeat*max(int(np.ceil(self.nsamples_ref/num_samples)), 1)) + if shuffle: - dataset = dataset.apply(tf.contrib.data.shuffle_and_repeat(buffer_size =1024, count = self.num_epochs)) + dataset = dataset.apply(tf.contrib.data.shuffle_and_repeat(buffer_size=1024, count=nrepeat)) else: - dataset = dataset.repeat(self.num_epochs) - - if self.mode == "val": dataset = dataset.repeat(20) + dataset = dataset.repeat(nrepeat) num_parallel_calls = None if shuffle else 1 dataset = dataset.apply(tf.contrib.data.map_and_batch( @@ -200,8 +216,7 @@ class ERA5Dataset(object): return iterator.get_next() - - +# further auxiliary methods def _bytes_feature(value): return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) @@ -209,10 +224,10 @@ def _bytes_feature(value): def _bytes_list_feature(values): return tf.train.Feature(bytes_list=tf.train.BytesList(value=values)) + def _floats_feature(value): return tf.train.Feature(float_list=tf.train.FloatList(value=value)) + def _int64_feature(value): return tf.train.Feature(int64_list=tf.train.Int64List(value=[value])) - - diff --git a/video_prediction_tools/model_modules/video_prediction/datasets/kth_dataset.py b/video_prediction_tools/model_modules/video_prediction/datasets/kth_dataset.py index db1417e7f19a41757f57b2462fad1b8d3be6de95..1c29fe5fa11c406f8de8c60fcd29d4bc8de60e10 100644 --- a/video_prediction_tools/model_modules/video_prediction/datasets/kth_dataset.py +++ b/video_prediction_tools/model_modules/video_prediction/datasets/kth_dataset.py @@ -45,11 +45,9 @@ class KTHVideoDataset(object): self.get_example_info() - def get_default_hparams(self): return HParams(**self.get_default_hparams_dict()) - def get_default_hparams_dict(self): """ The function that contains default hparams @@ -72,9 +70,6 @@ class KTHVideoDataset(object): ) return hparams - - - def get_datasplit(self): """ Get the datasplit json file @@ -171,7 +166,6 @@ def save_tf_record(output_fname, sequences): writer.write(example.SerializeToString()) - def read_frames_and_save_tf_records(output_dir, video_dirs, image_size, sequences_per_file=128): partition_name = os.path.split(output_dir)[1] #Get the folder name train, val or test sequences = [] diff --git a/video_prediction_tools/model_modules/video_prediction/datasets/moving_mnist.py b/video_prediction_tools/model_modules/video_prediction/datasets/moving_mnist.py index b3b2e63baffeed58978543dc283788503b1197be..45a51248592e5a94ff951e00a143a9fcd6abc482 100644 --- a/video_prediction_tools/model_modules/video_prediction/datasets/moving_mnist.py +++ b/video_prediction_tools/model_modules/video_prediction/datasets/moving_mnist.py @@ -6,12 +6,11 @@ __email__ = "b.gong@fz-juelich.de" __author__ = "Bing Gong, Karim" __date__ = "2021-05-03" - - import glob import os import random import json +import numpy as np import tensorflow as tf from tensorflow.contrib.training import HParams from collections import OrderedDict @@ -19,24 +18,34 @@ from google.protobuf.json_format import MessageToDict class MovingMnist(object): - def __init__(self, input_dir=None, datasplit_config=None, hparams_dict_config=None, mode="train",seed=None): - """ - This class is used for preparing the data for moving mnist, and split the data to train/val/testing - :params input_dir: the path of tfrecords files - :params datasplit_config: the path pointing to the datasplit_config json file - :params hparams_dict_config: the path to the dict that contains hparameters - :params mode: string, "train","val" or "test" - :params seed:int, the seed for dataset - :return None - """ + def __init__(self, input_dir: str = None, datasplit_config: str = None, hparams_dict_config: str = None, + mode: str = "train", seed: int = None, nsamples_ref: int = None): + """ + This class is used for preparing data for training/validation and test models + :param input_dir: the path of tfrecords files + :param datasplit_config: the path pointing to the datasplit_config json file + :param hparams_dict_config: the path to the dict that contains hparameters, + :param mode: string, "train","val" or "test" + :param seed: int, the seed for dataset + :param nsamples_ref: number of reference samples whch can be used to control repetition factor for dataset + for ensuring adopted size of dataset iterator (used for validation data during training) + Example: Let nsamples_ref be 1000 while the current datset consists 100 samples, then + the repetition-factor will be 10 (i.e. nsamples*rep_fac = nsamples_ref) + """ + method = self.__class__.__name__ + self.input_dir = input_dir self.mode = mode self.seed = seed self.sequence_length = None # will be set in get_example_info + self.shuffled = False # will be set properly in make_dataset-method + # sanity checks if self.mode not in ('train', 'val', 'test'): - raise ValueError('Invalid mode %s' % self.mode) + raise ValueError('%{0}: Invalid mode {1}'.format(method, self.mode)) if not os.path.exists(self.input_dir): - raise FileNotFoundError("input_dir %s does not exist" % self.input_dir) + raise FileNotFoundError("%{0} input_dir '{1}' does not exist".format(method, self.input_dir)) + if nsamples_ref is not None: + self.nsamples_ref = nsamples_ref self.datasplit_dict_path = datasplit_config self.data_dict = self.get_datasplit() self.hparams_dict_config = hparams_dict_config @@ -50,8 +59,8 @@ class MovingMnist(object): Get the datasplit json file """ with open(self.datasplit_dict_path) as f: - self.d = json.load(f) - return self.d + datasplit_dict = json.load(f) + return datasplit_dict def get_model_hparams_dict(self): """ @@ -62,7 +71,6 @@ class MovingMnist(object): with open(self.hparams_dict_config) as f: self.model_hparams_dict_load.update(json.loads(f.read())) return self.model_hparams_dict_load - def parse_hparams(self): """ @@ -74,9 +82,7 @@ class MovingMnist(object): def get_default_hparams(self): return HParams(**self.get_default_hparams_dict()) - def get_default_hparams_dict(self): - """ The function that contains default hparams Returns: @@ -91,15 +97,14 @@ class MovingMnist(object): hparams = dict( context_frames=10, sequence_length=20, - max_epochs = 20, - batch_size = 40, - lr = 0.001, - loss_fun = "rmse", - shuffle_on_val= True, + max_epochs=20, + batch_size=40, + lr=0.001, + loss_fun="rmse", + shuffle_on_val=True, ) return hparams - def get_tfrecords_filename_base_datasplit(self): """ Get obsoluate .tfrecords names based on the data splits patterns @@ -121,12 +126,11 @@ class MovingMnist(object): if not self.filenames: raise FileNotFoundError('No tfrecords were found in %s' % self.input_dir) - @staticmethod def string_filter(max_value=None, min_value=None, string="input_directory/sequence_index_0_index_10.tfrecords"): a = os.path.split(string)[-1].split("_") if not len(a) == 5: - raise ("The tfrecords pattern does not match the expected pattern, for instanct: 'sequence_index_0_to_10.tfrecords'") + raise ("The tfrecords pattern does not match the expected pattern, for instance: 'sequence_index_0_to_10.tfrecords'") min_index = int(a[2]) max_index = int(a[4].split(".")[0]) if min_index >= min_value and max_index <= max_value: @@ -157,10 +161,9 @@ class MovingMnist(object): with open(num_seq_file, 'r') as dfile: num_seqs = dfile.readlines() num_sequences = [int(num_seq.strip()) for num_seq in num_seqs] - self.num_examples_per_epoch = len_fnames * num_sequences[0] - - return self.num_examples_per_epoch + num_examples_per_epoch = len_fnames * num_sequences[0] + return num_examples_per_epoch def make_dataset(self, batch_size): """ @@ -171,7 +174,10 @@ class MovingMnist(object): args: batch_size: int, the size of samples fed into the models per iteration """ + method = MovingMnist.make_dataset.__name__ + self.num_epochs = self.hparams.max_epochs + def parser(serialized_example): seqs = OrderedDict() keys_to_features = { @@ -192,13 +198,19 @@ class MovingMnist(object): filenames = self.filenames shuffle = self.mode == 'train' or (self.mode == 'val' and self.hparams.shuffle_on_val) if shuffle: + self.shuffled = True random.shuffle(filenames) - dataset = tf.data.TFRecordDataset(filenames, buffer_size = 8* 1024 * 1024) + dataset = tf.data.TFRecordDataset(filenames, buffer_size=8*1024*1024) + # set-up dataset iterator + nrepeat = self.num_epochs + if self.nsamples_ref: + num_samples = self.num_examples_per_epoch() + nrepeat = int(nrepeat*max(int(np.ceil(self.nsamples_ref/num_samples)), 1)) + if shuffle: - dataset = dataset.apply(tf.contrib.data.shuffle_and_repeat(buffer_size =1024, count=self.num_epochs)) + dataset = dataset.apply(tf.contrib.data.shuffle_and_repeat(buffer_size=1024, count=nrepeat)) else: - dataset = dataset.repeat(self.num_epochs) - if self.mode == "val": dataset = dataset.repeat(20) + dataset = dataset.repeat(nrepeat) num_parallel_calls = None if shuffle else 1 dataset = dataset.apply(tf.contrib.data.map_and_batch( parser, batch_size, drop_remainder=True, num_parallel_calls=num_parallel_calls)) @@ -210,6 +222,8 @@ class MovingMnist(object): iterator = dataset.make_one_shot_iterator() return iterator.get_next() + +# further auxiliary methods def _bytes_feature(value): return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) @@ -217,9 +231,11 @@ def _bytes_feature(value): def _bytes_list_feature(values): return tf.train.Feature(bytes_list=tf.train.BytesList(value=values)) + def _floats_feature(value): return tf.train.Feature(float_list=tf.train.FloatList(value=value)) + def _int64_feature(value): return tf.train.Feature(int64_list=tf.train.Int64List(value=[value])) diff --git a/video_prediction_tools/model_modules/video_prediction/models/base_model.py b/video_prediction_tools/model_modules/video_prediction/models/base_model.py index 2bc8a399a49a10f3df99f9646c040140970d573c..1857f8b915d62646dff9a73d63f16e78656ddc57 100644 --- a/video_prediction_tools/model_modules/video_prediction/models/base_model.py +++ b/video_prediction_tools/model_modules/video_prediction/models/base_model.py @@ -2,6 +2,7 @@ # # SPDX-License-Identifier: MIT +import functools import itertools import os import re diff --git a/video_prediction_tools/postprocess/statistical_evaluation.py b/video_prediction_tools/postprocess/statistical_evaluation.py index 965165a4afc6967e0cadce4ffd93da3a44f14dc0..960df504a5ddc921087c00286369f8cbe850e2ee 100644 --- a/video_prediction_tools/postprocess/statistical_evaluation.py +++ b/video_prediction_tools/postprocess/statistical_evaluation.py @@ -21,7 +21,7 @@ try: l_tqdm = True except: l_tqdm = False -from general_utils import provide_default +from general_utils import provide_default, check_str_in_list # basic data types da_or_ds = Union[xr.DataArray, xr.Dataset] @@ -107,7 +107,7 @@ def avg_metrics(metric: da_or_ds, dim_name: str): :return: DataArray or Dataset of metric averaged over given dimension. If a Dataset is passed, the averaged metrics carry the suffix "_avg" in their variable names. """ - method = perform_block_bootstrap_metric.__name__ + method = avg_metrics.__name__ if not isinstance(metric, da_or_ds.__args__): raise ValueError("%{0}: Input metric must be a xarray DataArray or Dataset and not {1}".format(method, @@ -205,9 +205,6 @@ class Scores: """ Class to calculate scores and skill scores. """ - - known_scores = ["mse", "psnr", "ssim", "acc"] - def __init__(self, score_name: str, dims: List[str]): """ Initialize score instance. @@ -216,9 +213,9 @@ class Scores: :return: Score instance """ method = Scores.__init__.__name__ - self.metrics_dict = {"mse": self.calc_mse_batch , "psnr": self.calc_psnr_batch, "ssim":self.calc_ssim_batch, "acc":self.calc_acc_batch} - if set(self.metrics_dict.keys()) != set(Scores.known_scores): - raise ValueError("%{0}: Known scores must coincide with keys of metrics_dict.".format(method)) + self.metrics_dict = {"mse": self.calc_mse_batch , "psnr": self.calc_psnr_batch, + "ssim": self.calc_ssim_batch, "acc": self.calc_acc_batch, + "texture": self.calc_spatial_variability} self.score_name = self.set_score_name(score_name) self.score_func = self.metrics_dict[score_name] # attributes set when run_calculation is called @@ -291,10 +288,10 @@ class Scores: method = Scores.calc_ssim_batch.__name__ batch_size = np.array(data_ref).shape[0] fore_hours = np.array(data_fcst).shape[1] - ssim_pred = [[ssim(data_ref[i,j,:,:],data_fcst[i,j,:,:]) for j in range(fore_hours)] for i in range(batch_size)] + ssim_pred = [[ssim(data_ref[i,j, ...],data_fcst[i,j,...]) for j in range(fore_hours)] + for i in range(batch_size)] return ssim_pred - def calc_acc_batch(self, data_fcst, data_ref, **kwargs): """ Calculate acc ealuation metric of forecast data w.r.t reference data @@ -309,21 +306,15 @@ class Scores: else: raise KeyError("%{0}: climatological data must be parsed to calculate the ACC.".format(method)) - #print(data_fcst) - #print('data_clim shape: ',data_clim.shape) batch_size = data_fcst.shape[0] fore_hours = data_fcst.shape[1] - #print('batch_size: ',batch_size) - #print('fore_hours: ',fore_hours) acc = np.ones([batch_size,fore_hours])*np.nan for i in range(batch_size): for j in range(fore_hours): - img_fcst = data_fcst[i,j,:,:] - img_ref = data_ref[i,j,:,:] + img_fcst = data_fcst[i, j, ...] + img_ref = data_ref[i, j, ...] # get the forecast time - print('img_fcst.init_time: ',img_fcst.init_time) fcst_time = xr.Dataset({'time': pd.to_datetime(img_fcst.init_time.data) + datetime.timedelta(hours=j)}) - print('fcst_time: ',fcst_time.time) img_month = fcst_time.time.dt.month img_hour = fcst_time.time.dt.hour img_clim = data_clim.sel(month=img_month, hour=img_hour) @@ -336,5 +327,94 @@ class Scores: img2_ = img_fcst - img_clim cor1 = np.sum(img1_*img2_) cor2 = np.sqrt(np.sum(img1_**2)*np.sum(img2_**2)) - acc[i,j] = cor1/cor2 + acc[i, j] = cor1/cor2 return acc + + def calc_spatial_variability(self, data_fcst, data_ref, **kwargs): + """ + Calculates the ratio between the spatial variability of differental operator with order 1 (or 2) forecast and + reference data + :param data_fcst: data_fcst: forecasted data (xarray with dimensions [batch, fore_hours, lat, lon]) + :param data_ref: reference data (xarray with dimensions [batch, fore_hours, lat, lon]) + :param kwargs: order to control the order of spatial differential operator, 'non_spatial_avg_dims' to perform + averaging + :return: the ratio between spatial variabilty in the forecast and reference data field + """ + + method = Scores.calc_spatial_variability.__name__ + + if self.avg_dims is None: + pass + else: + print("%{0}: Passed dimensions to Scores-object instance are ignored.".format(method) + + "Make use of 'non_spatial_avg_dims' to pass a list over dimensions for averaging") + + if "order" in kwargs: + order = kwargs.get("order") + else: + order = 1 + + if "non_spatial_avg_dims" in kwargs: + add_avg_dims = kwargs.get("non_spatial_avg_dims") + else: + add_avg_dims = None + + fcst_grad = Scores.calc_geo_spatial_diff(data_fcst, order=order) + ref_grd = Scores.calc_geo_spatial_diff(data_ref, order=order) + + ratio_spat_variability = fcst_grad/ref_grd + + if add_avg_dims: ratio_spat_variability = ratio_spat_variability.mean(dim=add_avg_dims) + + return ratio_spat_variability + + @staticmethod + def calc_geo_spatial_diff(scalar_field: xr.DataArray, order: int = 1, r_e: float = 6371.e3, avg_dom: bool = True): + """ + Calculates the amplitude of the gradient (order=1) or the Laplacian (order=2) of a scalar field given on a regular, + geographical grid (i.e. dlambda = const. and dphi=const.) + :param scalar_field: scalar field as data array with latitude and longitude as coordinates + :param order: order of spatial differential operator + :param r_e: radius of the sphere + :param avg_dom: flag if amplitude is averaged over the domain + :return: the amplitude of the gradient/laplacian at each grid point or over the whole domain (see avg_dom) + """ + method = Scores.calc_geo_spatial_diff.__name__ + + # sanity checks + assert isinstance(scalar_field, xr.DataArray), "%{0}: scalar_field must be a xarray DataArray."\ + .format(method) + assert order in [1, 2], "%{0}: Order must be either 1 or 2.".format(method) + + dims = list(scalar_field.dims) + lat_dims = ["lat", "latitude"] + lon_dims = ["lon", "longitude"] + + def check_for_coords(coord_names_data, coord_names_expected): + for coord in coord_names_expected: + stat, ind_coord = check_str_in_list(coord_names_data, coord, return_ind=True) + if stat: + return ind_coord[0], coord_names_data[ind_coord[0]] # just take the first value + + raise ValueError("%{0}: Could not find one of the following coordinates in the passed dictionary." + .format(method, ",".join(coord_names_expected))) + + lat_ind, lat_name = check_for_coords(dims, lat_dims) + lon_ind, lon_name = check_for_coords(dims, lon_dims) + + lat, lon = np.deg2rad(scalar_field[lat_name]), np.deg2rad(scalar_field[lon_name]) + + dphi, dlambda = lat[1].values - lat[0].values, lon[1].values - lon[0].values + + if order == 1: + dvar_dlambda = 1./(r_e*np.cos(lat)*np.deg2rad(dlambda))*scalar_field.differentiate(lon_name) + dvar_dphi = 1./(r_e*np.deg2rad(dphi))*scalar_field.differentiate(lat_name) + dvar_dlambda = dvar_dlambda.transpose(*scalar_field.dims) # ensure that dimension ordering is not changed + + var_diff_amplitude = np.sqrt(dvar_dlambda**2 + dvar_dphi**2) + if avg_dom: var_diff_amplitude = var_diff_amplitude.mean(dim=[lat_name, lon_name]) + else: + raise ValueError("%{0}: Second-order differentation is not implemenetd yet.".format(method)) + + return var_diff_amplitude + diff --git a/video_prediction_tools/utils/general_utils.py b/video_prediction_tools/utils/general_utils.py index d18c9d11000df9cb73b0d41ffde3f5ece518982c..f4031cee83fc27bd14beaedd455f32e9b86a42fd 100644 --- a/video_prediction_tools/utils/general_utils.py +++ b/video_prediction_tools/utils/general_utils.py @@ -12,6 +12,7 @@ Provides: * get_unique_vars * check_str_in_list * check_dir * reduce_dict + * find_key * provide_default """ @@ -101,7 +102,6 @@ def isw(value, interval): :param interval: The interval defined by lower and upper bound :return status: True if value lies in interval """ - method = isw.__name__ if np.shape(interval)[0] != 2: @@ -137,8 +137,8 @@ def check_str_in_list(list_in: List, str2check: str_or_List, labort: bool = True if isinstance(str2check, str): str2check = [str2check] elif isinstance(str2check, list): - assert np.all([isinstance(str1, str) for str1 in str2check]) == True, \ - "Not all elements of str2check are strings" + assert np.all([isinstance(str1, str) for str1 in str2check]), "Not all elements of str2check are strings"\ + .format(method) else: raise ValueError("%{0}: str2check argument must be either a string or a list of strings".format(method))