from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
__email__ = "b.gong@fz-juelich.de"
__author__ = "Bing Gong, Yan Ji, Michael Langguth"
__date__ = "2020-11-10"
import argparse
import os
import shutil
import numpy as np
import xarray as xr
import pandas as pd
import tensorflow as tf
import pickle
import datetime as dt
import json
from typing import Union, List
# own modules
from general_utils import get_era5_varatts
from normalization import Norm_data
from general_utils import check_dir
from metadata import MetaData as MetaData
from main_scripts.main_train_models import *
from data_preprocess.preprocess_data_step2 import *
from model_modules.video_prediction import datasets, models, metrics
from statistical_evaluation import perform_block_bootstrap_metric, avg_metrics, calculate_cond_quantiles, Scores
from postprocess_plotting import plot_avg_eval_metrics, plot_cond_quantile, create_geo_contour_plot
class Postprocess(TrainModel):
def __init__(self, results_dir=None, checkpoint=None, mode="test", batch_size=None, num_stochastic_samples=1,
stochastic_plot_id=0, gpu_mem_frac=None, seed=None, channel=0, args=None, run_mode="deterministic",
eval_metrics=None):
"""
Initialization of the class instance for postprocessing (generation of forecasts from trained model +
basic evaluation).
:param results_dir: output directory to save results
:param checkpoint: directory pointing to the model checkpoints
:param mode: mode of dataset to be processed ("train", "val" or "test"), default: "test"
:param batch_size: mini-batch size for generating forecasts from trained model
:param num_stochastic_samples: number of ensemble members for variational models (SAVP, VAE), default: 1
(not supported yet)
:param stochastic_plot_id: index of the stochastic sample to be plotted (not supported yet)
:param gpu_mem_frac: fraction of GPU memory to be pre-allocated
:param seed: Integer controlling randomization
:param channel: Channel of interest for statistical evaluation
:param args: namespace of parsed arguments
:param run_mode: "deterministic" or "stochastic", default: "deterministic", "stochastic is not supported yet!!!
:param eval_metrics: metrics used to evaluate the trained model
"""
# copy over attributes from parsed arguments
self.results_dir = self.output_dir = os.path.normpath(results_dir)
_ = check_dir(self.results_dir, lcreate=True)
self.batch_size = batch_size
self.gpu_mem_frac = gpu_mem_frac
self.seed = seed
self.set_seed()
self.num_stochastic_samples = num_stochastic_samples
self.stochastic_plot_id = stochastic_plot_id
self.args = args
self.checkpoint = checkpoint
_ = check_dir(self.checkpoint)
self.run_mode = run_mode
self.mode = mode
self.channel = channel
# Attributes set during runtime
self.norm_cls = None
# configuration of basic evaluation
self.eval_metrics = eval_metrics
self.nboots_block = 1000
self.block_length = 5  # multiply by 24 to obtain a block length in days in case of hourly forecasts
# initialize everything to get an executable Postprocess instance
self.save_args_to_option_json() # create options.json in results directory
self.copy_data_model_json() # copy over JSON-files from model directory
# get some parameters related to model and dataset
self.datasplit_dict, self.model_hparams_dict, self.dataset, self.model, self.input_dir_tfr = self.load_jsons()
self.model_hparams_dict_load = self.get_model_hparams_dict()
# set input paths and forecast product dictionary
self.input_dir, self.input_dir_pkl = self.get_input_dirs()
self.fcst_products = {"persistence": "pfcst", self.model: "mfcst"}
# correct number of stochastic samples if necessary
self.check_num_stochastic_samples()
# get metadata
md_instance = self.get_metadata()
self.height, self.width = md_instance.ny, md_instance.nx
self.vars_in = md_instance.variables
self.lats, self.lons = md_instance.get_coord_array()
# get statistics JSON-file
self.stat_fl = self.set_stat_file()
self.cond_quantile_vars = self.init_cond_quantile_vars()
# setup test dataset and model
self.test_dataset, self.num_samples_per_epoch = self.setup_test_dataset()
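# NOTE: hard-coded cap on the number of evaluated samples (presumably a debugging leftover);
# remove or adjust the following line to evaluate the full test dataset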
self.num_samples_per_epoch = 100
self.sequence_length, self.context_frames, self.future_length = self.get_data_params()
self.inputs, self.input_ts = self.make_test_dataset_iterator()
# set-up model, its graph and do GPU-configuration (from TrainModel)
self.setup_model()
self.setup_graph()
self.setup_gpu_config()
# Methods that are called during initialization
def get_input_dirs(self):
"""
Retrieves top-level input directory and nested pickle-directory from input_dir_tfr
:return input_dir: top-level input directory
:return input_dir_pkl: Input directory where pickle-files are placed
"""
method = Postprocess.get_input_dirs.__name__
if not hasattr(self, "input_dir_tfr"):
raise AttributeError("Attribute input_dir_tfr is still missing.".format(method))
_ = check_dir(self.input_dir_tfr)
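# input_dir_tfr is expected to be a subdirectory of the top-level input directory;
# the pickle-files for the persistence forecast are expected in a sibling 'pickle' subdirectory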
input_dir = os.path.dirname(self.input_dir_tfr.rstrip("/"))
input_dir_pkl = os.path.join(input_dir, "pickle")
_ = check_dir(input_dir_pkl)
return input_dir, input_dir_pkl
# methods that are executed with __call__
def save_args_to_option_json(self):
"""
Save the arguments defined by the user to the results directory
"""
with open(os.path.join(self.results_dir, "options.json"), "w") as f:
f.write(json.dumps(vars(self.args), sort_keys=True, indent=4))
def copy_data_model_json(self):
"""
Copy relevant JSON-files from checkpoints directory to results_dir
"""
method_name = Postprocess.copy_data_model_json.__name__
# correctness of self.checkpoint and self.results_dir is already checked in __init__
model_opt_js = os.path.join(self.checkpoint, "options.json")
model_ds_js = os.path.join(self.checkpoint, "dataset_hparams.json")
model_hp_js = os.path.join(self.checkpoint, "model_hparams.json")
model_dd_js = os.path.join(self.checkpoint, "data_dict.json")
if os.path.isfile(model_opt_js):
shutil.copy(model_opt_js, os.path.join(self.results_dir, "options_checkpoints.json"))
else:
raise FileNotFoundError("%{0}: The file {1} does not exist".format(method_name, model_opt_js))
if os.path.isfile(model_ds_js):
shutil.copy(model_ds_js, os.path.join(self.results_dir, "dataset_hparams.json"))
else:
raise FileNotFoundError("%{0}: the file {1} does not exist".format(method_name, model_ds_js))
if os.path.isfile(model_hp_js):
shutil.copy(model_hp_js, os.path.join(self.results_dir, "model_hparams.json"))
else:
raise FileNotFoundError("%{0}: The file {1} does not exist".format(method_name, model_hp_js))
if os.path.isfile(model_dd_js):
shutil.copy(model_dd_js, os.path.join(self.results_dir, "data_dict.json"))
else:
raise FileNotFoundError("%{0}: The file {1} does not exist".format(method_name, model_dd_js))
def load_jsons(self):
"""
Set paths to the JSON-files which track essential information and load some of their content
into attributes of the class instance
:return datasplit_dict: path to datasplit-dictionary JSON-file of trained model
:return model_hparams_dict: path to model hyperparameter-dictionary JSON-file of trained model
:return dataset: Name of the dataset used to train the model
:return model: Name of trained model
:return input_dir_tfr: path to input directory where TF-records are stored
"""
method_name = Postprocess.load_jsons.__name__
datasplit_dict = os.path.join(self.results_dir, "data_dict.json")
model_hparams_dict = os.path.join(self.results_dir, "model_hparams.json")
checkpoint_opt_dict = os.path.join(self.results_dir, "options_checkpoints.json")
# sanity checks on the JSON-files
if not os.path.isfile(datasplit_dict):
raise FileNotFoundError("%{0}: The file data_dict.json is missing in {1}".format(method_name,
self.results_dir))
if not os.path.isfile(model_hparams_dict):
raise FileNotFoundError("%{0}: The file model_hparams.json is missing in {1}".format(method_name,
self.results_dir))
if not os.path.isfile(checkpoint_opt_dict):
raise FileNotFoundError("%{0}: The file options_checkpoints.json is missing in {1}"
.format(method_name, self.results_dir))
# retrieve some data from options_checkpoints.json
try:
with open(checkpoint_opt_dict) as f:
options_checkpoint = json.loads(f.read())
dataset = options_checkpoint["dataset"]
model = options_checkpoint["model"]
input_dir_tfr = options_checkpoint["input_dir"]
except Exception as err:
print("%{0}: Something went wrong when reading the checkpoint-file '{1}'".format(method_name,
checkpoint_opt_dict))
raise err
return datasplit_dict, model_hparams_dict, dataset, model, input_dir_tfr
def get_metadata(self):
method_name = Postprocess.get_metadata.__name__
# some sanity checks
if self.input_dir is None:
raise AttributeError("%{0}: input_dir-attribute is still None".format(method_name))
metadata_fl = os.path.join(self.input_dir, "metadata.json")
if not os.path.isfile(metadata_fl):
raise FileNotFoundError("%{0}: Could not find metadata JSON-file under '{1}'".format(method_name,
self.input_dir))
try:
md_instance = MetaData(json_file=metadata_fl)
except Exception as err:
print("%{0}: Something went wrong when getting metadata from file '{1}'".format(method_name, metadata_fl))
raise err
return md_instance
def setup_test_dataset(self):
"""
setup the test dataset instance
:return test_dataset: the test dataset instance
"""
VideoDataset = datasets.get_dataset_class(self.dataset)
test_dataset = VideoDataset(input_dir=self.input_dir_tfr, mode=self.mode, datasplit_config=self.datasplit_dict)
nsamples = test_dataset.num_examples_per_epoch()
return test_dataset, nsamples
def get_data_params(self):
"""
Get context_frames, future_length and sequence_length from the hyperparameter settings.
Note that future_length is the number of predicted frames.
"""
method = Postprocess.get_data_params.__name__
if not hasattr(self, "model_hparams_dict_load"):
raise AttributeError("%{0}: Attribute model_hparams_dict_load is still unset.".format(method))
try:
context_frames = self.model_hparams_dict_load["context_frames"]
sequence_length = self.model_hparams_dict_load["sequence_length"]
except Exception as err:
print("%{0}: Could not retrieve context_frames and sequence_length from model_hparams_dict_load-attribute"
.format(method))
raise err
future_length = sequence_length - context_frames
if future_length <= 0:
raise ValueError("Calculated future_length must be greater than zero.".format(method))
return sequence_length, context_frames, future_length
def set_stat_file(self):
"""
Set the name of the statistic file from the input directory
:return stat_fl: Path to statistics JSON-file of input data used to train the model
"""
method = Postprocess.set_stat_file.__name__
if not hasattr(self, "input_dir"):
raise AttributeError("%{0}: Attribute input_dir is still unset".format(method))
stat_fl = os.path.join(self.input_dir, "statistics.json")
if not os.path.isfile(stat_fl):
raise FileNotFoundError("%{0}: Cannot find statistics JSON-file '{1}'".format(method, stat_fl))
return stat_fl
def init_cond_quantile_vars(self):
"""
Get a list of variable names for conditional quantile plot
:return cond_quantile_vars: list holding the variable names of interest
"""
method = Postprocess.init_cond_quantile_vars.__name__
if not hasattr(self, "model"):
raise AttributeError("%{0}: Attribute model is still unset.".format(method))
cond_quantile_vars = ["{0}_{1}_fcst".format(self.vars_in[self.channel], self.model),
"{0}_ref".format(self.vars_in[self.channel])]
return cond_quantile_vars
def make_test_dataset_iterator(self):
"""
Make the dataset iterator
"""
method = Postprocess.make_test_dataset_iterator.__name__
if not hasattr(self, "test_dataset"):
raise AttributeError("%{0}: Attribute test_dataset is still unset".format(method))
if not hasattr(self, "batch_size"):
raise AttributeError("%{0}: Attribute batch_sie is still unset".format(method))
test_tf_dataset = self.test_dataset.make_dataset(self.batch_size)
test_iterator = test_tf_dataset.make_one_shot_iterator()
# The `Iterator.string_handle()` method returns a tensor that can be evaluated
# and used to feed the `handle` placeholder.
test_handle = test_iterator.string_handle()
dataset_iterator = tf.data.Iterator.from_string_handle(test_handle, test_tf_dataset.output_types,
test_tf_dataset.output_shapes)
input_iter = dataset_iterator.get_next()
ts_iter = input_iter["T_start"]
return input_iter, ts_iter
def check_num_stochastic_samples(self):
"""
Stochastic forecasting is only suitable for generative models such as SAVP and VAE.
convLSTM and McNet only support deterministic forecasting.
"""
method = Postprocess.check_num_stochastic_samples.__name__
if not hasattr(self, "model"):
raise AttributeError("%{0}: Attribute model is still unset".format(method))
if not hasattr(self, "num_stochastic_samples"):
raise AttributeError("%{0}: Attribute num_stochastic_samples is still unset".format(method))
if self.model == "convLSTM" or self.model == "test_model" or self.model == 'mcnet':
if self.num_stochastic_samples > 1:
print("Number of samples for deterministic model cannot be larger than 1. Higher values are ignored.")
self.num_stochastic_samples = 1
# the run-factory
def run(self):
if self.model == "convLSTM" or self.model == "test_model" or self.model == 'mcnet':
self.run_deterministic()
elif self.run_mode == "deterministic":
self.run_deterministic()
else:
self.run_stochastic()
def run_stochastic(self):
"""
Run the session, save results to netCDF, and plot input, generated and persistence images
"""
method = Postprocess.run_stochastic.__name__
raise ValueError("ML: %{0} is not runnable now".format(method))
self.init_session()
self.restore(self.sess, self.checkpoint)
# Loop for samples
self.sample_ind = 0
self.prst_metric_all = [] # store evaluation metrics of persistence forecast (shape [future_len])
self.fcst_metric_all = [] # store evaluation metric of stochastic forecasts (shape [nstoch, batch, future_len])
while self.sample_ind < self.num_samples_per_epoch:
if self.num_samples_per_epoch < self.sample_ind:
break
else:
# run the inputs and plot each sequence images
self.input_results, self.input_images_denorm_all, self.t_starts = self.get_input_data_per_batch()
feed_dict = {input_ph: self.input_results[name] for name, input_ph in self.inputs.items()}
gen_loss_stochastic_batch = [] # [stochastic_ind,future_length]
gen_images_stochastic = [] # [stochastic_ind,batch_size,seq_len,lat,lon,channels]
# Loop for stochastics
for stochastic_sample_ind in range(self.num_stochastic_samples):
print("stochastic_sample_ind:", stochastic_sample_ind)
# return [batchsize,seq_len,lat,lon,channel]
gen_images = self.sess.run(self.video_model.outputs['gen_images'], feed_dict=feed_dict)
# The sequence length of the generated images should be sequence_length - 1, since the last frame is
# not used for comparison with the ground truth
assert gen_images.shape[1] == self.sequence_length - 1
gen_images_per_batch = []
if stochastic_sample_ind == 0:
persistent_images_per_batch = [] # [batch_size,seq_len,lat,lon,channel]
ts_batch = []
for i in range(self.batch_size):
# generate time stamps for sequences only once, since they are the same for all ensemble members
if stochastic_sample_ind == 0:
self.ts = Postprocess.generate_seq_timestamps(self.t_starts[i], len_seq=self.sequence_length)
init_date_str = self.ts[0].strftime("%Y%m%d%H")
ts_batch.append(init_date_str)
# get persistence_images
self.persistence_images, self.ts_persistence = Postprocess.get_persistence(self.ts,
self.input_dir_pkl)
persistent_images_per_batch.append(self.persistence_images)
assert len(np.array(persistent_images_per_batch).shape) == 5
self.plot_persistence_images()
# Denormalize the generated data
gen_images_ = gen_images[i]
self.gen_images_denorm = Postprocess.denorm_images_all_channels(self.stat_fl, gen_images_,
self.vars_in)
gen_images_per_batch.append(self.gen_images_denorm)
assert len(np.array(gen_images_per_batch).shape) == 5
# only plot for the first stochastic index, otherwise too many plots would be created
# only plot the stochastic results of the user-defined index
self.plot_generate_images(stochastic_sample_ind, self.stochastic_plot_id)
# calculate the persistence error per batch
if stochastic_sample_ind == 0:
persistent_loss_per_batch = Postprocess.calculate_metrics_by_batch(self.input_images_denorm_all,
persistent_images_per_batch,
self.future_length,
self.context_frames,
matric="mse", channel=0)
self.prst_metric_all.append(persistent_loss_per_batch)
# calculate the gen_images_per_batch error
gen_loss_per_batch = Postprocess.calculate_metrics_by_batch(self.input_images_denorm_all,
gen_images_per_batch, self.future_length,
self.context_frames,
matric="mse", channel=0)
gen_loss_stochastic_batch.append(
gen_loss_per_batch) # self.gen_images_stochastic[stochastic,future_length]
print("gen_images_per_batch shape:", np.array(gen_images_per_batch).shape)
gen_images_stochastic.append(
gen_images_per_batch) # [stochastic,batch_size, seq_len, lat, lon, channel]
# Switch the 0 and 1 position
print("before transpose:", np.array(gen_images_stochastic).shape)
gen_images_stochastic = np.transpose(np.array(gen_images_stochastic), (
1, 0, 2, 3, 4, 5)) # [batch_size, stochastic, seq_len, lat, lon, chanel]
Postprocess.check_gen_images_stochastic_shape(gen_images_stochastic)
assert len(gen_images_stochastic.shape) == 6
assert np.array(gen_images_stochastic).shape[1] == self.num_stochastic_samples
self.fcst_metric_all.append(
gen_loss_stochastic_batch) # [samples/batch_size,stochastic,future_length]
# save input and stochastic generated images to netCDF file
# For each prediction (either deterministic or ensemble) we create one netCDF file.
for batch_id in range(self.batch_size):
self.save_to_netcdf_for_stochastic_generate_images(self.input_images_denorm_all[batch_id],
persistent_images_per_batch[batch_id],
np.array(gen_images_stochastic)[batch_id],
fl_name="vfp_date_{}_sample_ind_{}.nc"
.format(ts_batch[batch_id],
self.sample_ind + batch_id))
self.sample_ind += self.batch_size
self.persistent_loss_all_batches = np.mean(np.array(self.persistent_loss_all_batches), axis=0)
self.stochastic_loss_all_batches = np.mean(np.array(self.stochastic_loss_all_batches), axis=0)
assert len(np.array(self.persistent_loss_all_batches).shape) == 1
assert np.array(self.persistent_loss_all_batches).shape[0] == self.future_length
print("Bug here:", np.array(self.stochastic_loss_all_batches).shape)
assert len(np.array(self.stochastic_loss_all_batches).shape) == 2
assert np.array(self.stochastic_loss_all_batches).shape[0] == self.num_stochastic_samples
def run_deterministic(self):
"""
Revised and vectorized version of run_deterministic
Loops over the training data, generates forecasts and calculates basic evaluation metrics on-the-fly
"""
method = Postprocess.run_deterministic.__name__
# init the session and restore the trained model
self.init_session()
self.restore(self.sess, self.checkpoint)
# init sample index for looping
sample_ind = 0
nsamples = self.num_samples_per_epoch
# initialize xarray datasets
eval_metric_ds = Postprocess.init_metric_ds(self.fcst_products, self.eval_metrics, self.vars_in[self.channel],
nsamples, self.future_length)
cond_quantiple_ds = None
while sample_ind < self.num_samples_per_epoch:
# get normalized and denormalized input data
input_results, input_images_denorm, t_starts = self.get_input_data_per_batch(self.inputs)
# feed and run the trained model; returned array has the shape [batchsize, seq_len, lat, lon, channel]
feed_dict = {input_ph: input_results[name] for name, input_ph in self.inputs.items()}
gen_images = self.sess.run(self.video_model.outputs['gen_images'], feed_dict=feed_dict)
# sanity check on length of forecast sequence
assert gen_images.shape[1] == self.sequence_length - 1, \
"%{0}: Sequence length of prediction must be smaller by one than total sequence length.".format(method)
# denormalize forecast sequence (self.norm_cls is already set in get_input_data_per_batch-method)
gen_images_denorm = self.denorm_images_all_channels(gen_images, self.vars_in, self.norm_cls,
norm_method="minmax")
# store data into dataset & get number of samples (may differ from batch_size at the end of the test dataset)
times_0, init_times = self.get_init_time(t_starts)
batch_ds = self.create_dataset(input_images_denorm, gen_images_denorm, init_times)
nbs = np.minimum(self.batch_size, self.num_samples_per_epoch - sample_ind)
batch_ds = batch_ds.isel(init_time=slice(0, nbs))
for i in np.arange(nbs):
# work-around to make use of get_persistence_forecast_per_sample-method
times_seq = (pd.date_range(times_0[i], periods=int(self.sequence_length), freq="h")).to_pydatetime()
# get persistence forecast for sequences at hand and write to dataset
persistence_seq, _ = Postprocess.get_persistence(times_seq, self.input_dir_pkl)
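# keep only the forecast period of the persistence sequence, i.e. the last future_length time steps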
for ivar, var in enumerate(self.vars_in):
batch_ds["{0}_persistence_fcst".format(var)].loc[dict(init_time=init_times[i])] = \
persistence_seq[self.context_frames-1:, :, :, ivar]
# save sequences to netcdf-file and track initial time
nc_fname = os.path.join(self.results_dir, "vfp_date_{0}_sample_ind_{1:d}.nc"
.format(pd.to_datetime(init_times[i]).strftime("%Y%m%d%H"), sample_ind + i))
self.save_ds_to_netcdf(batch_ds.isel(init_time=i), nc_fname)
# end of batch-loop
# write evaluation metrics to the corresponding dataset and append data for the conditional quantile plot
eval_metric_ds = self.populate_eval_metric_ds(eval_metric_ds, batch_ds, sample_ind,
self.vars_in[self.channel])
cond_quantiple_ds = Postprocess.append_ds(batch_ds, cond_quantiple_ds, self.cond_quantile_vars, "init_time")
# ... and increment sample_ind
sample_ind += self.batch_size
# end of while-loop for samples
# save dataset with evaluation metrics for later use
self.eval_metrics_ds = eval_metric_ds
self.cond_quantiple_ds = cond_quantiple_ds
#self.add_ensemble_dim()
# all methods of the run factory
def init_session(self):
"""
Initialize TensorFlow-session
:return: -
"""
method = Postprocess.init_session.__name__
if not hasattr(self, "config"):
raise AttributeError("Attribute config is still unset.".format(method))
self.sess = tf.Session(config=self.config)
self.sess.graph.as_default()
self.sess.run(tf.global_variables_initializer())
self.sess.run(tf.local_variables_initializer())
def get_input_data_per_batch(self, input_iter, norm_method="minmax"):
"""
Get the input sequence from the dataset iterator object stored in self.inputs and denormalize the data
:param input_iter: the iterator object built by make_test_dataset_iterator-method
:param norm_method: normalization method applicable to the data
:return input_results: the normalized input data
:return input_images_denorm: the denormalized input data
:return t_starts: the initial time of the sequences
"""
method = Postprocess.get_input_data_per_batch.__name__
input_results = self.sess.run(input_iter)
input_images = input_results["images"]
t_starts = input_results["T_start"]
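# lazily instantiate the normalization object from the statistics JSON-file on first use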
if self.norm_cls is None:
if self.stat_fl is None:
raise AttributeError("%{0}: Attribute stat_fl is not initialized yet.".format(method))
self.norm_cls = Postprocess.get_norm(self.vars_in, self.stat_fl, norm_method)
# sanity check on input sequence
assert np.ndim(input_images) == 5, "%{0}: Input sequence of mini-batch does not have five dimensions."\
.format(method)
input_images_denorm = Postprocess.denorm_images_all_channels(input_images, self.vars_in, self.norm_cls,
norm_method=norm_method)
return input_results, input_images_denorm, t_starts
def get_init_time(self, t_starts):
"""
Retrieves initial dates of forecast sequences from the start time of the whole input sequence
:param t_starts: list/array of start times of input sequence
:return: list of initial dates of forecast as numpy.datetime64 instances
"""
method = Postprocess.get_init_time.__name__
t_starts = np.squeeze(np.asarray(t_starts))
if not np.ndim(t_starts) == 1:
raise ValueError("%{0}: Inputted t_starts must be a 1D list/array of date-strings with format %Y%m%d%H"
.format(method))
for i, t_start in enumerate(t_starts):
try:
seq_ts = pd.date_range(dt.datetime.strptime(str(t_start), "%Y%m%d%H"), periods=self.context_frames,
freq="h")
except Exception as err:
print("%{0}: Could not convert {1} to datetime object. Ensure that the date-string format is 'Y%m%d%H'".
format(method, str(t_start)))
raise err
if i == 0:
ts_all = np.expand_dims(seq_ts, axis=0)
else:
ts_all = np.vstack((ts_all, seq_ts))
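# the forecast initialization time is the last time step of the context (input) period,
# while times0 denotes the very first time step of each input sequence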
init_times = ts_all[:, -1]
times0 = ts_all[:, 0]
return times0, init_times
def populate_eval_metric_ds(self, metric_ds, data_ds, ind_start, varname):
"""
Populates evaluation metric dataset with values
:param metric_ds: the evaluation metric dataset with variables such as 'mfcst_mse' (MSE of model forecast)
:param data_ds: dataset holding the data from one mini-batch (see create_dataset-method)
:param ind_start: start index of dimension init_time (part of metric_ds)
:param varname: variable of interest (must be part of self.vars_in)
:return: metric_ds
"""
method = Postprocess.populate_eval_metric_ds.__name__
# dictionary of implemented evaluation metrics
dims = ["lat", "lon"]
known_eval_metrics = {"mse": Scores("mse", dims), "psnr": Scores("psnr", dims)}
# generate list of functions that calculate requested evaluation metrics
if set(self.eval_metrics).issubset(known_eval_metrics.keys()):
eval_metrics_func = [known_eval_metrics[metric].score_func for metric in self.eval_metrics]
else:
misses = list(set(self.eval_metrics) - known_eval_metrics.keys())
raise NotImplementedError("%{0}: The following requested evaluation metrics are not implemented yet: {1}"
.format(method, ", ".join(misses)))
varname_ref = "{0}_ref".format(varname)
# reset init-time coordinate of metric_ds in place and get indices for slicing
ind_end = np.minimum(ind_start + self.batch_size, self.num_samples_per_epoch)
init_times_metric = metric_ds["init_time"].values
init_times_metric[ind_start:ind_end] = data_ds["init_time"]
metric_ds = metric_ds.assign_coords(init_time=init_times_metric)
# populate metric_ds
for fcst_prod in self.fcst_products.keys():
for imetric, eval_metric in enumerate(self.eval_metrics):
metric_name = "{0}_{1}_{2}".format(varname, fcst_prod, eval_metric)
varname_fcst = "{0}_{1}_fcst".format(varname, fcst_prod)
dict_ind = dict(init_time=data_ds["init_time"])
metric_ds[metric_name].loc[dict_ind] = eval_metrics_func[imetric](data_ds[varname_fcst],
data_ds[varname_ref])
# end of metric-loop
# end of forecast product-loop
return metric_ds
def add_ensemble_dim(self):
"""
Expands dimensions of loss-arrays by dummy ensemble-dimension (used for deterministic forecasts only)
:return:
"""
self.stochastic_loss_all_batches = np.expand_dims(self.fcst_mse_avg_batches, axis=0)  # [1, future_length]
self.stochastic_loss_all_batches_psnr = np.expand_dims(self.fcst_psnr_avg_batches, axis=0)  # [1, future_length]
def create_dataset(self, input_seq, fcst_seq, ts_ini):
"""
Put input and forecast sequences into a xarray dataset. The latter also involves the persistence forecast
which is just initialized, but unpopulated at this stage.
The input data sequence is split into (effective) input sequence used for the forecast and into reference part.
:param input_seq: sequence of input images [batch ,seq, lat, lon, channel]
:param fcst_seq: sequence of forecast images [batch ,seq-1, lat, lon, channel]
:param ts_ini: initial time of forecast (=last time step of effective input sequence)
:return data_ds: above mentioned data in a nicely formatted dataset
"""
method = Postprocess.create_dataset.__name__
# auxiliary variables for temporal dimensions
seq_hours = np.arange(self.sequence_length) - (self.context_frames-1)
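# fcst_hour ranges from -(context_frames-1) to future_length, where 0 marks the forecast initialization time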
# some sanity checks
assert np.shape(ts_ini)[0] == self.batch_size,\
"%{0}: Inconsistent number of sequence start times ({1:d}) and batch size ({2:d})"\
.format(method, np.shape(ts_ini)[0], self.batch_size)
# turn input and forecast sequences to Data Arrays to ease indexing
try:
input_seq = xr.DataArray(input_seq, coords={"init_time": ts_ini, "fcst_hour": seq_hours,
"lat": self.lats, "lon": self.lons, "varname": self.vars_in},
dims=["init_time", "fcst_hour", "lat", "lon", "varname"])
except Exception as err:
print("%{0}: Could not create Data Array for input sequence.".format(method))
raise err
try:
fcst_seq = xr.DataArray(fcst_seq, coords={"init_time": ts_ini, "fcst_hour": seq_hours[1::],
"lat": self.lats, "lon": self.lons, "varname": self.vars_in},
dims=["init_time", "fcst_hour", "lat", "lon", "varname"])
except Exception as err:
print("%{0}: Could not create Data Array for forecast sequence.".format(method))
raise err
# Now create the dataset where the input sequence is split into the input that served for creating the
# forecast and into the reference sequence (which can be compared to the forecast),
# whereas the persistence forecast is filled with NaNs (must be populated later)
data_in_dict = dict([("{0}_in".format(var), input_seq.isel(fcst_hour=slice(None, self.context_frames),
varname=ivar)
.rename({"fcst_hour": "in_hour"})
.reset_coords(names="varname", drop=True))
for ivar, var in enumerate(self.vars_in)])
# get shape of forecast data (one variable) -> required to initialize persistence forecast data
shape_fcst = np.shape(fcst_seq.isel(fcst_hour=slice(self.context_frames-1, None), varname=0)
.reset_coords(names="varname", drop=True))
data_ref_dict = dict([("{0}_ref".format(var), input_seq.isel(fcst_hour=slice(self.context_frames, None),
varname=ivar)
.reset_coords(names="varname", drop=True))
for ivar, var in enumerate(self.vars_in)])
data_mfcst_dict = dict([("{0}_{1}_fcst".format(var, self.model),
fcst_seq.isel(fcst_hour=slice(self.context_frames-1, None), varname=ivar)
.reset_coords(names="varname", drop=True))
for ivar, var in enumerate(self.vars_in)])
# fill persistence forecast variables with dummy data (to be populated later)
data_pfcst_dict = dict([("{0}_persistence_fcst".format(var), (["init_time", "fcst_hour", "lat", "lon"],
np.full(shape_fcst, np.nan)))
for ivar, var in enumerate(self.vars_in)])
# create the dataset
data_ds = xr.Dataset({**data_in_dict, **data_ref_dict, **data_mfcst_dict, **data_pfcst_dict})
return data_ds
def handle_eval_metrics(self):
"""
Plots error-metrics averaged over all predictions to file.
:return: a bunch of plots as png-files
"""
method = Postprocess.handle_eval_metrics.__name__
if self.eval_metrics_ds is None:
raise AttributeError("%{0}: Attribute with dataset of evaluation metrics is still None.".format(method))
# perform bootstrapping on metric dataset
eval_metric_boot_ds = perform_block_bootstrap_metric(self.eval_metrics_ds, "init_time", self.block_length,
self.nboots_block)
# ... and merge into existing metric dataset
self.eval_metrics_ds = xr.merge([self.eval_metrics_ds, eval_metric_boot_ds])
# calculate (unbootstrapped) averaged metrics
eval_metric_avg_ds = avg_metrics(self.eval_metrics_ds, "init_time")
# ... and merge into existing metric dataset
self.eval_metrics_ds = xr.merge([self.eval_metrics_ds, eval_metric_avg_ds])
# save evaluation metrics to file
nc_fname = os.path.join(self.results_dir, "evaluation_metrics.nc")
Postprocess.save_ds_to_netcdf(self.eval_metrics_ds, nc_fname)
# also save averaged metrics to JSON-file and plot it for diagnosis
_ = plot_avg_eval_metrics(self.eval_metrics_ds, self.eval_metrics, self.fcst_products,
self.vars_in[self.channel], self.results_dir)
def plot_example_forecasts(self, metric="mse", channel=0):
"""
Plots example forecasts. The forecasts are chosen from the complete pool of the test dataset and are chosen
according to the accuracy in terms of the chosen metric. In addition to the best and worst forecast,
every decile of the chosen metric is retrieved to cover the whole bandwidth of forecasts.
:param metric: The metric which is used for measuring accuracy
:param channel: The channel index of the forecasted variable to plot (corresponding to self.vars_in)
:return: 11 exemplary forecast plots are created
"""
method = Postprocess.plot_example_forecasts.__name__
metric_name = "{0}_{1}_{2}".format(self.vars_in[channel], self.model, metric)
if not metric_name in self.eval_metrics_ds:
raise ValueError("%{0}: Cannot find requested evaluation metric '{1}'".format(method, metric_name) +
" onto which selection of plotted forecast is done.")
# average metric of interest and obtain quantiles incl. indices
metric_mean = self.eval_metrics_ds[metric_name].mean(dim="fcst_hour")
quantiles = np.arange(0., 1.01, .1)
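# 0th to 100th percentile in steps of 10% -> 11 example forecasts covering the whole accuracy range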
quantiles_val = metric_mean.quantile(quantiles, interpolation="nearest")
quantiles_inds = self.get_matching_indices(metric_mean.values, quantiles_val)
for i, ifcst in enumerate(quantiles_inds):
date_init = pd.to_datetime(metric_mean.coords["init_time"][ifcst].data)
nc_fname = os.path.join(self.results_dir, "vfp_date_{0}_sample_ind_{1:d}.nc"
.format(date_init.strftime("%Y%m%d%H"), ifcst))
if not os.path.isfile(nc_fname):
raise FileNotFoundError("%{0}: Could not find requested file '{1}'".format(method, nc_fname))
else:
# get the data
varname = self.vars_in[channel]
with xr.open_dataset(nc_fname) as dfile:
data_fcst = dfile["{0}_{1}_fcst".format(varname, self.model)]
data_ref = dfile["{0}_ref".format(varname)]
data_diff = data_fcst - data_ref
# name of plot
plt_fname_base = os.path.join(self.output_dir, "forecast_{0}_{1}_{2}_{3:d}percentile.png"
.format(varname, date_init.strftime("%Y%m%dT%H00"), metric,
int(quantiles[i] * 100.)))
create_geo_contour_plot(data_fcst, data_diff, varname, plt_fname_base)
def plot_conditional_quantiles(self):
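"""
Creates conditional quantile plots (calibration-refinement and likelihood-base rate factorizations)
for each forecast hour of the variable of interest.
"""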
# release some memory
Postprocess.clean_obj_attribute(self, "eval_metrics_ds")
# the variables for conditional quantile plot
var_fcst = "{0}_{1}_fcst".format(self.vars_in[self.channel], self.model)
var_ref = "{0}_ref".format(self.vars_in[self.channel])
data_fcst = get_era5_varatts(self.cond_quantiple_ds[var_fcst], self.cond_quantiple_ds[var_fcst].name)
data_ref = get_era5_varatts(self.cond_quantiple_ds[var_ref], self.cond_quantiple_ds[var_ref].name)
# create plots
fhhs = data_fcst.coords["fcst_hour"]
for hh in fhhs:
# calibration refinement factorization
plt_fname_cf = os.path.join(self.results_dir, "cond_quantile_{0}_{1}_fh{2:0d}_calibration_refinement.png"
.format(self.vars_in[self.channel], self.model, int(hh)))
quantile_panel_cf, cond_variable_cf = calculate_cond_quantiles(data_fcst.sel(fcst_hour=hh),
data_ref.sel(fcst_hour=hh),
factorization="calibration_refinement",
quantiles=(0.05, 0.5, 0.95))
plot_cond_quantile(quantile_panel_cf, cond_variable_cf, plt_fname_cf)
# likelihood-base rate factorization
plt_fname_lbr = plt_fname_cf.replace("calibration_refinement", "likelihood-base_rate")
quantile_panel_lbr, cond_variable_lbr = calculate_cond_quantiles(data_fcst.sel(fcst_hour=hh),
data_ref.sel(fcst_hour=hh),
factorization="likelihood-base_rate",
quantiles=(0.05, 0.5, 0.95))
plot_cond_quantile(quantile_panel_lbr, cond_variable_lbr, plt_fname_lbr)
@staticmethod
def clean_obj_attribute(obj, attr_name, lremove=False):
"""
Cleans attribute of object by setting it to None (can be used to release memory)
:param obj: the object/ class instance
:param attr_name: the attribute from the object to be cleaned
:param lremove: flag if attribute is removed or set to None
:return: the object/class instance with the attribute's value changed to None
"""
method = Postprocess.clean_obj_attribute.__name__
if not hasattr(obj, attr_name):
print("%{0}: Class attribute '{1}' does not exist. Nothing to do...".format(method, attr_name))
else:
if lremove:
delattr(obj, attr_name)
else:
setattr(obj, attr_name, None)
return obj
# auxiliary methods (not necessarily bound to class instance)
@staticmethod
def get_norm(varnames, stat_fl, norm_method):
"""
Retrieves normalization instance
:param varnames: list of variable names
:param stat_fl: statistics JSON-file
:param norm_method: normalization method
:return: normalization instance which can be used to normalize images according to norm_method
"""
method = Postprocess.get_norm.__name__
if not isinstance(varnames, list):
raise ValueError("%{0}: varnames must be a list of variable names.".format(method))
norm_cls = Norm_data(varnames)
try:
with open(stat_fl) as js_file:
norm_cls.check_and_set_norm(json.load(js_file), norm_method)
except Exception as err:
print("%{0}: Could not handle statistics json-file '{1}'.".format(method, stat_fl))
raise err
return norm_cls
@staticmethod
def denorm_images_all_channels(image_sequence, varnames, norm, norm_method="minmax"):
"""
Denormalize data of all image channels
:param image_sequence: list/array [batch, seq, lat, lon, channel] of images
:param varnames: list of variable names whose order matches channel indices
:param norm: normalization instance
:param norm_method: normalization-method (default: 'minmax')
:return: denormalized image data
"""
method = Postprocess.denorm_images_all_channels.__name__
nvars = len(varnames)
image_sequence = np.array(image_sequence)
# sanity checks
if not isinstance(norm, Norm_data):
raise ValueError("%{0}: norm must be a normalization instance.".format(method))
if nvars != np.shape(image_sequence)[-1]:
raise ValueError("%{0}: Number of passed variable names ({1:d}) does not match number of channels ({2:d})"
.format(method, nvars, np.shape(image_sequence)[-1]))
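# denormalize each channel separately and stack the results back along the last (channel) axis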
input_images_all_channels_denorm = [Postprocess.denorm_images(image_sequence, norm, {varname: c},
norm_method=norm_method)
for c, varname in enumerate(varnames)]
input_images_denorm = np.stack(input_images_all_channels_denorm, axis=-1)
return input_images_denorm
@staticmethod
def denorm_images(input_images, norm, var_dict, norm_method="minmax"):
"""
Denormalize one channel of images
:param input_images: list/array [batch, seq, lat, lon, channel]
:param norm: normalization instance
:param var_dict: dictionary with one key only mapping variable name to channel index, e.g. {"2_t": 0}
:param norm_method: normalization method (default: minmax-normalization)
:return: denormalized image data
"""
method = Postprocess.denorm_images.__name__
# sanity checks
if not isinstance(var_dict, dict):
raise ValueError("%{0}: var_dict is not a dictionary.".format(method))
else:
if len(var_dict.keys()) > 1:
raise ValueError("%{0}: var_dict must contain one key only.".format(method))
varname, channel = *var_dict.keys(), *var_dict.values()
if not isinstance(norm, Norm_data):
raise ValueError("%{0}: norm must be a normalization instance.".format(method))
try:
input_images_denorm = norm.denorm_var(input_images[..., channel], varname, norm_method)
except Exception as err:
print("%{0}: Something went wrong when denormalizing image sequence. Inspect error-message!".format(method))
raise err
return input_images_denorm
@staticmethod
def check_gen_images_stochastic_shape(gen_images_stochastic):
"""
For models with deterministic forecasts, the ensemble dimension is lacking. Therefore, the array
is expanded by one dimension here.
"""
if len(np.array(gen_images_stochastic).shape) == 6:
pass
elif len(np.array(gen_images_stochastic).shape) == 5:
gen_images_stochastic = np.expand_dims(gen_images_stochastic, axis=0)
else:
raise ValueError("Passed gen_images_stochastic is not of the right shape")
return gen_images_stochastic
@staticmethod
def get_persistence(ts, input_dir_pkl):
"""
This function gets the persistence forecast.
'Today's weather will be like yesterday's weather.'
:param ts: list containing datetime objects from get_init_time
:param input_dir_pkl: input directory of the pickle files
:return var_persistence: sequence of images constituting the persistence forecast
:return time_persistence: list containing the dates and times of the persistence forecast
"""
year_origin = ts[0].year
ts_persistence = [t - dt.timedelta(days=1) for t in ts]
t_persistence_start = ts_persistence[0]
t_persistence_end = ts_persistence[-1]
year_start = t_persistence_start.year
month_start = t_persistence_start.month
month_end = t_persistence_end.month
print("start year:", year_start)
# only one pickle file is needed (all hours during the same month)
if month_start == month_end:
# Open files to search for the indices of the corresponding times
time_pickle = list(Postprocess.load_pickle_for_persistence(input_dir_pkl, year_start, month_start, 'T'))
# Open file to search for the corresponding meteorological fields
var_pickle = list(Postprocess.load_pickle_for_persistence(input_dir_pkl, year_start, month_start, 'X'))
if year_origin != year_start:
time_origin_pickle = list(Postprocess.load_pickle_for_persistence(input_dir_pkl, year_origin, 12, 'T'))
var_origin_pickle = list(Postprocess.load_pickle_for_persistence(input_dir_pkl, year_origin, 12, 'X'))
time_pickle.extend(time_origin_pickle)
var_pickle.extend(var_origin_pickle)
# Retrieve starting index
ind = list(time_pickle).index(np.array(ts_persistence[0]))
var_persistence = np.array(var_pickle)[ind:ind + len(ts_persistence)]
time_persistence = np.array(time_pickle)[ind:ind + len(ts_persistence)].ravel()
# case that we need to derive the data from two pickle files (changing month during the forecast period)
else:
t_persistence_first_m = [] # should hold dates of the first month
t_persistence_second_m = [] # should hold dates of the second month
for t in range(len(ts)):
m = ts_persistence[t].month
if m == month_start:
t_persistence_first_m.append(ts_persistence[t])
if m == month_end:
t_persistence_second_m.append(ts_persistence[t])
if year_origin == year_start:
# Open files to search for the indices of the corresponding times
time_pickle_first = Postprocess.load_pickle_for_persistence(input_dir_pkl, year_start, month_start, 'T')
time_pickle_second = Postprocess.load_pickle_for_persistence(input_dir_pkl, year_start, month_end, 'T')
# Open file to search for the corresponding meteorological fields
var_pickle_first = Postprocess.load_pickle_for_persistence(input_dir_pkl, year_start, month_start, 'X')
var_pickle_second = Postprocess.load_pickle_for_persistence(input_dir_pkl, year_start, month_end, 'X')
if year_origin != year_start:
# Open files to search for the indices of the corresponding times
time_pickle_second = Postprocess.load_pickle_for_persistence(input_dir_pkl, year_origin, 1, 'T')
time_pickle_first = Postprocess.load_pickle_for_persistence(input_dir_pkl, year_start, 12, 'T')
# Open file to search for the corresponding meteorological fields
var_pickle_second = Postprocess.load_pickle_for_persistence(input_dir_pkl, year_origin, 1, 'X')
var_pickle_first = Postprocess.load_pickle_for_persistence(input_dir_pkl, year_start, 12, 'X')
# Retrieve starting index
ind_first_m = list(time_pickle_first).index(np.array(t_persistence_first_m[0]))
# print("time_pickle_second:", time_pickle_second)
ind_second_m = list(time_pickle_second).index(np.array(t_persistence_second_m[0]))
# append the sequence of the second month to the first month
var_persistence = np.concatenate((var_pickle_first[ind_first_m:ind_first_m + len(t_persistence_first_m)],
var_pickle_second[
ind_second_m:ind_second_m + len(t_persistence_second_m)]),
axis=0)
time_persistence = np.concatenate((time_pickle_first[ind_first_m:ind_first_m + len(t_persistence_first_m)],
time_pickle_second[
ind_second_m:ind_second_m + len(t_persistence_second_m)]),
axis=0).ravel()
# Note: ravel is needed to eliminate the unnecessary dimension (20,1) becomes (20,)
if len(time_persistence.tolist()) == 0:
raise ValueError("The time_persistence is empty!")
if len(var_persistence) == 0:
raise ValueError("The var_persistence is empty!")
var_persistence = var_persistence[1:]
time_persistence = time_persistence[1:]
return var_persistence, time_persistence.tolist()
@staticmethod
def load_pickle_for_persistence(input_dir_pkl, year_start, month_start, pkl_type):
"""
There are two types in our workflow: T_[month].pkl where the timestamp is stored,
X_[month].pkl where the variables are stored, e.g. temperature, geopotential and pressure.
This helper function constructs the directory, opens the file to read it, returns the variable.
:param input_dir_pkl: directory where input pickle files are stored
:param year_start: The year for which data is requested as integer
:param month_start: The month for which data is requested as integer
:param pkl_type: Either "X" or "T"
"""
path_to_pickle = os.path.join(input_dir_pkl, str(year_start), pkl_type + "_{:02}.pkl".format(month_start))
with open(path_to_pickle, "rb") as pkl_file:
var = pickle.load(pkl_file)
return var
@staticmethod
def save_ds_to_netcdf(ds, nc_fname, comp_level=5):
"""
Writes xarray dataset into netCDF-file
:param ds: The dataset to be written
:param nc_fname: Path and name of the target netCDF-file
:param comp_level: compression level, must be an integer between 1 and 9 (default: 5)
:return: -
"""
method = Postprocess.save_ds_to_netcdf.__name__
# sanity checks
if not isinstance(ds, xr.Dataset):
raise ValueError("%{0}: Argument 'ds' must be a xarray dataset.".format(method))
if not isinstance(comp_level, int):
raise ValueError("%{0}: Argument 'comp_level' must be an integer.".format(method))
else:
if comp_level < 1 or comp_level > 9:
raise ValueError("%{0}: Argument 'comp_level' must be an integer between 1 and 9.".format(method))
if not os.path.isdir(os.path.dirname(nc_fname)):
raise NotADirectoryError("%{0}: The directory to store the netCDF-file does not exist.".format(method))
encode_nc = {key: {"zlib": True, "complevel": comp_level} for key in ds.keys()}
# populate data in netCDF-file (take care for the mode!)
try:
ds.to_netcdf(nc_fname, encoding=encode_nc)
print("%{0}: netCDF-file '{1}' was created successfully.".format(method, nc_fname))
except Exception as err:
print("%{0}: Something unexpected happened when creating netCDF-file '1'".format(method, nc_fname))
raise err
@staticmethod
def append_ds(ds_in: xr.Dataset, ds_preexist: xr.Dataset, varnames: list, dim2append: str):
"""
Append an existing dataset with a subset of another dataset based on selected variables
:param ds_in: the input dataset from which variables should be retrieved
:param ds_preexist: the accumulator dataset to be appended (can be initialized with None)
:param dim2append: name of the dimension along which ds_in is appended
:param varnames: List of variables that should be retrieved from ds_in and that are appended to ds_preexist
:return: appended version of ds_preexist
"""
method = Postprocess.append_ds.__name__
varnames_str = ",".join(varnames)
# sanity checks
if not isinstance(ds_in, xr.Dataset):
raise ValueError("%{0}: ds_in must be a xarray dataset, but is of type {1}".format(method, type(ds_in)))
if not set(varnames).issubset(ds_in.data_vars):
raise ValueError("%{0}: Could not find all variables ({1}) in input dataset ds_in.".format(method,
varnames_str))
if ds_preexist is None:
ds_preexist = ds_in[varnames].copy(deep=True)
return ds_preexist
else:
if not isinstance(ds_preexist, xr.Dataset):
raise ValueError("%{0}: ds_preexist must be a xarray dataset, but is of type {1}"
.format(method, type(ds_preexist)))
if not set(varnames).issubset(ds_preexist.data_vars):
raise ValueError("%{0}: Could not find all varibales ({1}) in pre-existing dataset ds_preexist"
.format(method, varnames_str))
try:
ds_preexist = xr.concat([ds_preexist, ds_in[varnames]], dim2append)
except Exception as err:
print("%{0}: Failed to concat datsets along dimension {1}.".format(method, dim2append))
print(ds_in)
print(ds_preexist)
raise err
return ds_preexist
@staticmethod
def init_metric_ds(fcst_products, eval_metrics, varname, nsamples, nlead_steps):
"""
Initializes dataset for storing evaluation metrics
:param fcst_products: list of forecast products to be evaluated
:param eval_metrics: list of forecast metrics to be calculated
:param varname: name of the variable for which metrics are calculated
:param nsamples: total number of forecast samples
:param nlead_steps: number of forecast steps
:return: eval_metric_ds
"""
eval_metric_dict = dict([("{0}_{1}_{2}".format(varname, *(fcst_prod, eval_met)), (["init_time", "fcst_hour"],
np.full((nsamples, nlead_steps), np.nan)))
for eval_met in eval_metrics for fcst_prod in fcst_products])
init_time_dummy = pd.date_range("1900-01-01 00:00", freq="s", periods=nsamples)
eval_metric_ds = xr.Dataset(eval_metric_dict, coords={"init_time": init_time_dummy, # just a placeholder
"fcst_hour": np.arange(1, nlead_steps+1)})
return eval_metric_ds
@staticmethod
def get_matching_indices(big_array, subset):
"""
Returns the indices where element values match the values in an array
:param big_array: the array to dig through
:param subset: array of values contained in big_array
:return: the desired indices
"""
sorted_keys = np.argsort(big_array)
indexes = sorted_keys[np.searchsorted(big_array, subset, sorter=sorted_keys)]
return indexes
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--results_dir", type=str, default='results',
help="ignored if output_gif_dir is specified")
parser.add_argument("--checkpoint",
help="directory with checkpoint or checkpoint name (e.g. checkpoint_dir/model-200000)")
parser.add_argument("--mode", type=str, choices=['train', 'val', 'test'], default='test',
help='mode of the dataset: train, val or test.')
parser.add_argument("--batch_size", type=int, default=8, help="number of samples in batch")
parser.add_argument("--num_stochastic_samples", type=int, default=1)
parser.add_argument("--gpu_mem_frac", type=float, default=0.95, help="fraction of gpu memory to use")
parser.add_argument("--seed", type=int, default=7)
parser.add_argument("--evaluation_metrics", "-eval_metrics", dest="eval_metrics", nargs="+", default=("mse", "psnr"),
help="Metrics to be evaluate the trained model. Must be known metrics, see Scores-class.")
parser.add_argument("--channel", "-channel", dest="channel", type=int, default=0,
help="Channel which is used for evaluation.")
args = parser.parse_args()
print('----------------------------------- Options ------------------------------------')
for k, v in args._get_kwargs():
print(k, "=", v)
print('------------------------------------- End --------------------------------------')
# initialize postprocessing instance
postproc_instance = Postprocess(results_dir=args.results_dir, checkpoint=args.checkpoint, mode="test",
batch_size=args.batch_size, num_stochastic_samples=args.num_stochastic_samples,
gpu_mem_frac=args.gpu_mem_frac, seed=args.seed, args=args,
eval_metrics=args.eval_metrics, channel=args.channel)
# run the postprocessing
postproc_instance.run()
postproc_instance.handle_eval_metrics()
postproc_instance.plot_example_forecasts(metric=args.eval_metrics[0], channel=args.channel)
postproc_instance.plot_conditional_quantiles()
if __name__ == '__main__':
main()