Skip to content
Snippets Groups Projects
Commit 6906ca5d authored by lukas leufen's avatar lukas leufen
Browse files

create bootstrap forecasts

parent 44a70581
No related branches found
No related tags found
2 merge requests!59Develop,!52implemented bootstraps
Pipeline #30492 passed
...@@ -31,16 +31,7 @@ class BootStrapGenerator: ...@@ -31,16 +31,7 @@ class BootStrapGenerator:
""" """
return len(self.orig_generator)*self.boots*len(self.variables) return len(self.orig_generator)*self.boots*len(self.variables)
# def __iter__(self): def get_generator(self):
# """
# Define the __iter__ part of the iterator protocol to iterate through this generator. Sets the private attribute
# `_iterator` to 0.
# :return:
# """
# self._iterator = 0
# return self
def __iter__(self):
""" """
This is the implementation of the __next__ method of the iterator protocol. Get the data generator, and return This is the implementation of the __next__ method of the iterator protocol. Get the data generator, and return
the history and label data of this generator. the history and label data of this generator.
...@@ -69,7 +60,6 @@ class BootStrapGenerator: ...@@ -69,7 +60,6 @@ class BootStrapGenerator:
files = os.listdir(self.bootstrap_path) files = os.listdir(self.bootstrap_path)
regex = re.compile(rf"{station}_\w*\.nc") regex = re.compile(rf"{station}_\w*\.nc")
file_name = os.path.join(self.bootstrap_path, list(filter(regex.search, files))[0]) file_name = os.path.join(self.bootstrap_path, list(filter(regex.search, files))[0])
# shuffled_data = xr.open_dataarray(file_name, chunks=self.chunksize)
shuffled_data = xr.open_dataarray(file_name, chunks=100) shuffled_data = xr.open_dataarray(file_name, chunks=100)
return shuffled_data return shuffled_data
...@@ -79,10 +69,8 @@ class BootStraps(RunEnvironment): ...@@ -79,10 +69,8 @@ class BootStraps(RunEnvironment):
def __init__(self, data, bootstrap_path, number_bootstraps=10): def __init__(self, data, bootstrap_path, number_bootstraps=10):
super().__init__() super().__init__()
# self.data: DataGenerator = self.data_store.get("generator", "general.test")
self.data: DataGenerator = data self.data: DataGenerator = data
self.number_bootstraps = number_bootstraps self.number_bootstraps = number_bootstraps
# self.bootstrap_path = self.data_store.get("bootstrap_path", "general")
self.bootstrap_path = bootstrap_path self.bootstrap_path = bootstrap_path
self.chunks = self.get_chunk_size() self.chunks = self.get_chunk_size()
self.create_shuffled_data() self.create_shuffled_data()
...@@ -92,7 +80,10 @@ class BootStraps(RunEnvironment): ...@@ -92,7 +80,10 @@ class BootStraps(RunEnvironment):
return self._boot_strap_generator.bootstrap_meta return self._boot_strap_generator.bootstrap_meta
def boot_strap_generator(self): def boot_strap_generator(self):
return self._boot_strap_generator return self._boot_strap_generator.get_generator()
def get_boot_strap_generator_length(self):
return self._boot_strap_generator.__len__()
def get_chunk_size(self): def get_chunk_size(self):
hist, _ = self.data[0] hist, _ = self.data[0]
...@@ -104,6 +95,7 @@ class BootStraps(RunEnvironment): ...@@ -104,6 +95,7 @@ class BootStraps(RunEnvironment):
randomly selected variables. If there is a suitable local file for requested window size and number of randomly selected variables. If there is a suitable local file for requested window size and number of
bootstraps, no additional file will be created inside this function. bootstraps, no additional file will be created inside this function.
""" """
logging.info("create shuffled bootstrap data")
variables_str = '_'.join(sorted(self.data.variables)) variables_str = '_'.join(sorted(self.data.variables))
window = self.data.window_history_size window = self.data.window_history_size
for station in self.data.stations: for station in self.data.stations:
......
...@@ -13,6 +13,7 @@ import xarray as xr ...@@ -13,6 +13,7 @@ import xarray as xr
from src import statistics from src import statistics
from src.data_handling.data_distributor import Distributor from src.data_handling.data_distributor import Distributor
from src.data_handling.data_generator import DataGenerator from src.data_handling.data_generator import DataGenerator
from src.data_handling.bootstraps import BootStraps
from src.datastore import NameNotFoundInDataStore from src.datastore import NameNotFoundInDataStore
from src.helpers import TimeTracking from src.helpers import TimeTracking
from src.model_modules.linear_model import OrdinaryLeastSquaredModel from src.model_modules.linear_model import OrdinaryLeastSquaredModel
...@@ -50,6 +51,25 @@ class PostProcessing(RunEnvironment): ...@@ -50,6 +51,25 @@ class PostProcessing(RunEnvironment):
"skip make_prediction() whenever it is possible to save time.") "skip make_prediction() whenever it is possible to save time.")
self.skill_scores = self.calculate_skill_scores() self.skill_scores = self.calculate_skill_scores()
self.plot() self.plot()
self.create_boot_straps()
def create_boot_straps(self):
bootstrap_path = self.data_store.get("bootstrap_path", "general")
forecast_path = self.data_store.get("forecast_path", "general")
window_lead_time = self.data_store.get("window_lead_time", "general")
bootstraps = BootStraps(self.test_data, bootstrap_path, 20)
with TimeTracking(name="boot predictions"):
bootstrap_predictions = self.model.predict_generator(generator=bootstraps.boot_strap_generator(),
steps=bootstraps.get_boot_strap_generator_length())
bootstrap_meta = np.array(bootstraps.get_boot_strap_meta())
length = sum(bootstrap_meta == bootstrap_meta[0])
variables = np.unique(bootstrap_meta)
for boot in variables:
ind = (bootstrap_meta == boot)
sel = bootstrap_predictions[ind].reshape((length, window_lead_time, 1))
tmp = xr.DataArray(sel, coords=(range(length), range(window_lead_time), [boot]), dims=["index", "window", "boot"])
file_name = os.path.join(forecast_path, f"bootstraps_{boot}.nc")
tmp.to_netcdf(file_name)
def _load_model(self): def _load_model(self):
try: try:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment