From d1f9bcb662408c133bf7ab82b2ffbf5ed99a24fc Mon Sep 17 00:00:00 2001
From: Michael <m.langguth@fz-juelich.de>
Date: Tue, 18 May 2021 17:57:28 +0200
Subject: [PATCH] Handling of truncated batches (at the end of the test
 dataset) in the postprocessing.

---
 .../main_scripts/main_visualize_postprocess.py       | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/video_prediction_tools/main_scripts/main_visualize_postprocess.py b/video_prediction_tools/main_scripts/main_visualize_postprocess.py
index 00f4991a..68017b45 100644
--- a/video_prediction_tools/main_scripts/main_visualize_postprocess.py
+++ b/video_prediction_tools/main_scripts/main_visualize_postprocess.py
@@ -428,13 +428,13 @@ class Postprocess(TrainModel):
             # denormalize forecast sequence (self.norm_cls is already set in get_input_data_per_batch-method)
             gen_images_denorm = self.denorm_images_all_channels(gen_images, self.vars_in, self.norm_cls,
                                                                 norm_method="minmax")
-            # store data into datset
+            # store data into dataset and get number of samples (may differ from batch_size at the end of the test dataset)
             times_0, init_times = self.get_init_time(t_starts)
             batch_ds = self.create_dataset(input_images_denorm, gen_images_denorm, init_times)
-            # auxilary list of forecast dimensions
-            dims_fcst = list(batch_ds["{0}_ref".format(self.vars_in[0])].dims)
+            nbs = np.minimum(self.batch_size, self.num_samples_per_epoch - sample_ind)
+            batch_ds = batch_ds.isel(init_time=slice(0, nbs))
 
-            for i in np.arange(self.batch_size):
+            for i in np.arange(nbs):
                 # work-around to make use of get_persistence_forecast_per_sample-method
                 times_seq = (pd.date_range(times_0[i], periods=int(self.sequence_length), freq="h")).to_pydatetime()
                 # get persistence forecast for sequences at hand and write to dataset
@@ -541,8 +541,8 @@ class Postprocess(TrainModel):
                                       .format(method, ", ".join(misses)))
 
         varname_ref = "{0}_ref".format(varname)
-        # reset init-time coordinate of metric_ds in place
-        ind_end = ind_start + min(self.batch_size, len(data_ds["init_time"]))
+        # reset init-time coordinate of metric_ds in place and get indices for slicing
+        ind_end = np.minimum(ind_start + self.batch_size, self.num_samples_per_epoch)
         init_times_metric = metric_ds["init_time"].values
         init_times_metric[ind_start:ind_end] = data_ds["init_time"]
         metric_ds = metric_ds.assign_coords(init_time=init_times_metric)
-- 
GitLab