Commit d1f9bcb6 authored by Michael Langguth

Handling of truncated batches (at the end of the test dataset) in the postprocessing.

parent fab232f6
Pipeline #68031 passed
@@ -428,13 +428,13 @@ class Postprocess(TrainModel):
 # denormalize forecast sequence (self.norm_cls is already set in get_input_data_per_batch-method)
 gen_images_denorm = self.denorm_images_all_channels(gen_images, self.vars_in, self.norm_cls,
                                                     norm_method="minmax")
-# store data into datset
+# store data into datset and get number of samples (may differ from batch_size at the end of the test dataset)
 times_0, init_times = self.get_init_time(t_starts)
 batch_ds = self.create_dataset(input_images_denorm, gen_images_denorm, init_times)
-# auxilary list of forecast dimensions
-dims_fcst = list(batch_ds["{0}_ref".format(self.vars_in[0])].dims)
-for i in np.arange(self.batch_size):
+nbs = np.minimum(self.batch_size, self.num_samples_per_epoch - sample_ind)
+batch_ds = batch_ds.isel(init_time=slice(0, nbs))
+for i in np.arange(nbs):
 # work-around to make use of get_persistence_forecast_per_sample-method
 times_seq = (pd.date_range(times_0[i], periods=int(self.sequence_length), freq="h")).to_pydatetime()
 # get persistence forecast for sequences at hand and write to dataset
@@ -541,8 +541,8 @@ class Postprocess(TrainModel):
 .format(method, ", ".join(misses)))
 varname_ref = "{0}_ref".format(varname)
-# reset init-time coordinate of metric_ds in place
-ind_end = ind_start + min(self.batch_size, len(data_ds["init_time"]))
+# reset init-time coordinate of metric_ds in place and get indices for slicing
+ind_end = np.minimum(ind_start + self.batch_size, self.num_samples_per_epoch)
 init_times_metric = metric_ds["init_time"].values
 init_times_metric[ind_start:ind_end] = data_ds["init_time"]
 metric_ds = metric_ds.assign_coords(init_time=init_times_metric)
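A minimal, standalone sketch of the truncated-batch arithmetic this commit introduces. The variable names mirror those in the diff (batch_size, num_samples_per_epoch, sample_ind); the numeric values are illustrative assumptions, not taken from the repository:

import numpy as np

# Illustrative values only; in the commit these come from self.batch_size,
# self.num_samples_per_epoch and the running sample index sample_ind.
batch_size = 32
num_samples_per_epoch = 1000
sample_ind = 992  # first sample index of the final, truncated batch

# Number of valid samples in this batch; smaller than batch_size at the end
# of the test dataset (here: 1000 - 992 = 8).
nbs = np.minimum(batch_size, num_samples_per_epoch - sample_ind)

# In the commit, the batch dataset is then sliced to its first nbs init times
# (batch_ds = batch_ds.isel(init_time=slice(0, nbs))), so padded samples at
# the tail of the last batch are never processed.

# End index for writing this batch's results into the epoch-wide metric
# dataset, clipped so the final batch cannot write past the last sample.
ind_start = sample_ind
ind_end = np.minimum(ind_start + batch_size, num_samples_per_epoch)

print(nbs, ind_start, ind_end)  # -> 8 992 1000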