diff --git a/video_prediction_tools/main_scripts/main_visualize_postprocess.py b/video_prediction_tools/main_scripts/main_visualize_postprocess.py index 00f4991a35cf11e5ae4307c30366a9f641c0b3bf..68017b4597080b771b00860ab32cf693c0714d73 100644 --- a/video_prediction_tools/main_scripts/main_visualize_postprocess.py +++ b/video_prediction_tools/main_scripts/main_visualize_postprocess.py @@ -428,13 +428,13 @@ class Postprocess(TrainModel): # denormalize forecast sequence (self.norm_cls is already set in get_input_data_per_batch-method) gen_images_denorm = self.denorm_images_all_channels(gen_images, self.vars_in, self.norm_cls, norm_method="minmax") - # store data into datset + # store data into datset and get number of samples (may differ from batch_size at the end of the test dataset) times_0, init_times = self.get_init_time(t_starts) batch_ds = self.create_dataset(input_images_denorm, gen_images_denorm, init_times) - # auxilary list of forecast dimensions - dims_fcst = list(batch_ds["{0}_ref".format(self.vars_in[0])].dims) + nbs = np.minimum(self.batch_size, self.num_samples_per_epoch - sample_ind) + batch_ds = batch_ds.isel(init_time=slice(0, nbs)) - for i in np.arange(self.batch_size): + for i in np.arange(nbs): # work-around to make use of get_persistence_forecast_per_sample-method times_seq = (pd.date_range(times_0[i], periods=int(self.sequence_length), freq="h")).to_pydatetime() # get persistence forecast for sequences at hand and write to dataset @@ -541,8 +541,8 @@ class Postprocess(TrainModel): .format(method, ", ".join(misses))) varname_ref = "{0}_ref".format(varname) - # reset init-time coordinate of metric_ds in place - ind_end = ind_start + min(self.batch_size, len(data_ds["init_time"])) + # reset init-time coordinate of metric_ds in place and get indices for slicing + ind_end = np.minimum(ind_start + self.batch_size, self.num_samples_per_epoch) init_times_metric = metric_ds["init_time"].values init_times_metric[ind_start:ind_end] = data_ds["init_time"] metric_ds = metric_ds.assign_coords(init_time=init_times_metric)