From 66aad8c7cd9edab2ebdef4741d7b85358d32d3a4 Mon Sep 17 00:00:00 2001
From: leufen1 <l.leufen@fz-juelich.de>
Date: Fri, 15 Oct 2021 13:56:22 +0200
Subject: [PATCH] can calculate block mse

---
 mlair/run_modules/post_processing.py | 64 ++++++++++++++++++++++++++++
 1 file changed, 64 insertions(+)

diff --git a/mlair/run_modules/post_processing.py b/mlair/run_modules/post_processing.py
index 72362c87..3cff0447 100644
--- a/mlair/run_modules/post_processing.py
+++ b/mlair/run_modules/post_processing.py
@@ -104,6 +104,9 @@ class PostProcessing(RunEnvironment):
         # calculate error metrics on test data
         self.calculate_test_score()
 
+        # sample uncertainty
+        self.estimate_sample_uncertainty()
+
         # bootstraps
         if self.data_store.get("evaluate_bootstraps", "postprocessing"):
             with TimeTracking(name="calculate bootstraps"):
@@ -126,6 +129,67 @@ class PostProcessing(RunEnvironment):
         # plotting
         self.plot()
 
+    def estimate_sample_uncertainty(self, evaluate_competitors=True, separate_ahead=False, block_length="1m"):
+        # block_length = self.data_store.get("sample_uncertainty_block_length")
+        block_mse = self.calculate_block_mse(evaluate_competitors=evaluate_competitors, separate_ahead=separate_ahead,
+                                             block_length=block_length)
+
+    def calculate_block_mse(self, evaluate_competitors=True, separate_ahead=False, block_length="1m"):
+        path = self.data_store.get("forecast_path")
+        all_stations = self.data_store.get("stations")
+        start = self.data_store.get("start", "test")
+        end = self.data_store.get("end", "test")
+        index_dim = "index"
+        coll_dim = "station"
+        collector = []
+        for station in all_stations:
+            external_data = self._get_external_data(station, path)  # test data
+
+            # test errors
+            if external_data is not None:
+                pass
+
+            # load competitors
+            if evaluate_competitors is True:
+                competitor = self.load_competitors(station)
+                combined = self._combine_forecasts(external_data, competitor, dim=self.model_type_dim)
+            else:
+                combined = external_data
+
+            #
+            if combined is None:
+                continue
+            else:
+                combined = self.create_full_time_dim(combined, index_dim, self._sampling, start, end)
+                errors = self.create_error_array(combined)  # get squared errors
+                mse = errors.resample(indexer={index_dim: block_length}).mean(skipna=True)
+                collector.append(mse.assign_coords({coll_dim: station}))
+        mse_blocks = xr.concat(collector, dim=coll_dim).mean(dim=coll_dim, skipna=True)
+        if separate_ahead is False:
+            mse_blocks = mse_blocks.mean(dim=self.ahead_dim, skipna=True)
+        return mse_blocks
+
+    def create_error_array(self, data):
+        """
+        Calculate squared error of all given time series in relation to observation.
+        """
+        errors = data.drop_sel({self.model_type_dim: self.observation_indicator})
+        errors1 = errors - data.sel({self.model_type_dim: self.observation_indicator})
+        errors2 = errors1 ** 2
+        return errors2
+
+    @staticmethod
+    def create_full_time_dim(data, dim, sampling, start, end):
+        """Ensure time dimension to be equidistant. Sometimes dates if missing values have been dropped."""
+        start_data = data.coords[dim].values[0]
+        freq = {"daily": "1D", "hourly": "1H"}.get(sampling)
+        datetime_index = pd.DataFrame(index=pd.date_range(start, end, freq=freq))
+        t = data.sel({dim: start_data}, drop=True)
+        res = xr.DataArray(coords=[datetime_index.index, *[t.coords[c] for c in t.coords]], dims=[dim, *t.coords])
+        res = res.transpose(*data.dims)
+        res.loc[data.coords] = data
+        return res
+
     def load_competitors(self, station_name: str) -> xr.DataArray:
         """
         Load all requested and available competitors for a given station. Forecasts must be available in the competitor
-- 
GitLab