diff --git a/mlair/helpers/statistics.py b/mlair/helpers/statistics.py index 9a0346e4f09bbbe75d2c8dd70dac2d26d1b5b146..aa89da0ea66e263d076af9abd578ba125c260bec 100644 --- a/mlair/helpers/statistics.py +++ b/mlair/helpers/statistics.py @@ -525,8 +525,8 @@ def calculate_average(data: xr.DataArray, **kwargs) -> xr.DataArray: return data.mean(**kwargs) -def create_n_bootstrap_realizations(data: xr.DataArray, dim_name_time, dim_name_model, n_boots: int = 1000, - dim_name_boots='boots') -> xr.DataArray: +def create_n_bootstrap_realizations(data: xr.DataArray, dim_name_time: str, dim_name_model: str, n_boots: int = 1000, + dim_name_boots: str = 'boots') -> xr.DataArray: """ Create n bootstrap realizations and calculate averages across realizations @@ -546,8 +546,7 @@ def create_n_bootstrap_realizations(data: xr.DataArray, dim_name_time, dim_name_ for boot in range(n_boots): res[boot] = (calculate_average( create_single_bootstrap_realization(data, dim_name_time=dim_name_time), - dim=dim_name_time - )) + dim=dim_name_time, skipna=True)) return res diff --git a/mlair/run_modules/post_processing.py b/mlair/run_modules/post_processing.py index e5e6b77196368a571124efc0844ba7f1bb8ed97f..e7ed04b2f8694e7e4e2c90d215cb042cb33beef8 100644 --- a/mlair/run_modules/post_processing.py +++ b/mlair/run_modules/post_processing.py @@ -131,10 +131,15 @@ class PostProcessing(RunEnvironment): self.plot() def estimate_sample_uncertainty(self, separate_ahead=False): + #todo: set n_boots + #todo: visualize + #todo: write results on disk block_length = self.data_store.get_default("uncertainty_estimate_block_length", default="1m") evaluate_competitors = self.data_store.get_default("uncertainty_estimate_evaluate_competitors", default=True) block_mse = self.calculate_block_mse(evaluate_competitors=evaluate_competitors, separate_ahead=separate_ahead, block_length=block_length) + res = statistics.create_n_bootstrap_realizations(block_mse, self.index_dim, self.model_type_dim, n_boots=10) + res def calculate_block_mse(self, evaluate_competitors=True, separate_ahead=False, block_length="1m"): path = self.data_store.get("forecast_path")