Skip to content
Snippets Groups Projects
Commit 6ae330b2 authored by lukas leufen's avatar lukas leufen
Browse files

use dask instead of numpy to create bootstrap data

parent c50e6c55
No related branches found
No related tags found
2 merge requests!59Develop,!52implemented bootstraps
Pipeline #29713 passed
...@@ -6,6 +6,7 @@ from src.run_modules.run_environment import RunEnvironment ...@@ -6,6 +6,7 @@ from src.run_modules.run_environment import RunEnvironment
from src.data_handling.data_generator import DataGenerator from src.data_handling.data_generator import DataGenerator
import numpy as np import numpy as np
import logging import logging
import dask.array as da
import xarray as xr import xarray as xr
import os import os
import re import re
...@@ -17,7 +18,7 @@ class BootStraps(RunEnvironment): ...@@ -17,7 +18,7 @@ class BootStraps(RunEnvironment):
super().__init__() super().__init__()
self.test_data: DataGenerator = self.data_store.get("generator", "general.test") self.test_data: DataGenerator = self.data_store.get("generator", "general.test")
self.number_bootstraps = 200 self.number_bootstraps = 100
self.bootstrap_path = self.data_store.get("bootstrap_path", "general") self.bootstrap_path = self.data_store.get("bootstrap_path", "general")
self.create_shuffled_data() self.create_shuffled_data()
...@@ -38,11 +39,12 @@ class BootStraps(RunEnvironment): ...@@ -38,11 +39,12 @@ class BootStraps(RunEnvironment):
file_name = f"{station}_{variables_str}_hist{window}_nboots{nboot}_shuffled.nc" file_name = f"{station}_{variables_str}_hist{window}_nboots{nboot}_shuffled.nc"
file_path = os.path.join(self.bootstrap_path, file_name) file_path = os.path.join(self.bootstrap_path, file_name)
data = data.expand_dims({'boots': range(nboot)}, axis=-1) data = data.expand_dims({'boots': range(nboot)}, axis=-1)
shuffled_variable = np.full(data.shape, np.nan) shuffled_variable = []
for i, var in enumerate(data.coords['variables']): for i, var in enumerate(data.coords['variables']):
single_variable = data.sel(variables=var).values single_variable = data.sel(variables=var).values
shuffled_variable[..., i, :] = self.shuffle_single_variable(single_variable) shuffled_variable.append(self.shuffle_single_variable(single_variable, chunks=(100, *data.shape[1:3], data.shape[-1])))
shuffled_data = xr.DataArray(shuffled_variable, coords=data.coords, dims=data.dims) shuffled_variable_da = da.stack(shuffled_variable, axis=-2, ).rechunk("auto")
shuffled_data = xr.DataArray(shuffled_variable_da, coords=data.coords, dims=data.dims)
shuffled_data.to_netcdf(file_path) shuffled_data.to_netcdf(file_path)
def valid_bootstrap_file(self, station, variables, window): def valid_bootstrap_file(self, station, variables, window):
...@@ -74,10 +76,9 @@ class BootStraps(RunEnvironment): ...@@ -74,10 +76,9 @@ class BootStraps(RunEnvironment):
return False, max_nboot return False, max_nboot
@staticmethod @staticmethod
def shuffle_single_variable(data: np.ndarray) -> np.ndarray: def shuffle_single_variable(data: da.array, chunks) -> np.ndarray:
orig_shape = data.shape size = data.shape
size = orig_shape return da.random.choice(data.reshape(-1,), size=size, chunks=chunks)
return np.random.choice(data.reshape(-1,), size=size)
if __name__ == "__main__": if __name__ == "__main__":
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment