Skip to content
Snippets Groups Projects
Commit 811bd7d6 authored by lukas leufen's avatar lukas leufen
Browse files

somre renaming and first tests for bootstrap generator

parent 46b3b49f
No related branches found
No related tags found
3 merge requests!90WIP: new release update,!89Resolve "release branch / CI on gpu",!61Resolve "REFAC: clean-up bootstrap workflow"
Pipeline #31451 passed
...@@ -10,30 +10,20 @@ import xarray as xr ...@@ -10,30 +10,20 @@ import xarray as xr
import os import os
import re import re
from src import helpers from src import helpers
from typing import List from typing import List, Union
class BootStrapGenerator: class BootStrapGenerator:
def __init__(self, orig_generator, boots, chunksize, bootstrap_path): def __init__(self, orig_generator, number_of_boots, bootstrap_path):
self.orig_generator: DataGenerator = orig_generator self.orig_generator: DataGenerator = orig_generator
self.stations = self.orig_generator.stations self.stations = self.orig_generator.stations
self.variables = self.orig_generator.variables self.variables = self.orig_generator.variables
self.boots = boots self.number_of_boots = number_of_boots
self.chunksize = chunksize
self.bootstrap_path = bootstrap_path self.bootstrap_path = bootstrap_path
self._iterator = 0
def __len__(self): def __len__(self):
""" return len(self.orig_generator) * self.number_of_boots * len(self.variables)
display the number of stations
"""
return len(self.orig_generator)*self.boots*len(self.variables)
def get_labels(self, key):
_, label = self.orig_generator[key]
for _ in range(self.boots):
yield label
def get_generator(self): def get_generator(self):
""" """
...@@ -46,10 +36,10 @@ class BootStrapGenerator: ...@@ -46,10 +36,10 @@ class BootStrapGenerator:
station = self.orig_generator.get_station_key(i) station = self.orig_generator.get_station_key(i)
logging.info(f"station: {station}") logging.info(f"station: {station}")
hist, label = data hist, label = data
shuffled_data = self.load_boot_data(station) shuffled_data = self.load_shuffled_data(station, self.variables)
for var in self.variables: for var in self.variables:
logging.info(f" var: {var}") logging.debug(f" var: {var}")
for boot in range(self.boots): for boot in range(self.number_of_boots):
logging.debug(f"boot: {boot}") logging.debug(f"boot: {boot}")
boot_hist = hist.sel(variables=helpers.list_pop(self.variables, var)) boot_hist = hist.sel(variables=helpers.list_pop(self.variables, var))
shuffled_var = shuffled_data.sel(variables=var, boots=boot).expand_dims("variables").drop("boots").transpose("datetime", "window", "Stations", "variables") shuffled_var = shuffled_data.sel(variables=var, boots=boot).expand_dims("variables").drop("boots").transpose("datetime", "window", "Stations", "variables")
...@@ -67,23 +57,54 @@ class BootStrapGenerator: ...@@ -67,23 +57,54 @@ class BootStrapGenerator:
for station in self.stations: for station in self.stations:
label = self.orig_generator.get_data_generator(station).get_transposed_label() label = self.orig_generator.get_data_generator(station).get_transposed_label()
for var in self.variables: for var in self.variables:
for boot in range(self.boots): for boot in range(self.number_of_boots):
bootstrap_meta.extend([[var, station]] * len(label)) bootstrap_meta.extend([[var, station]] * len(label))
return bootstrap_meta return bootstrap_meta
def get_orig_prediction(self, path, file_name, prediction_name="CNN"): def get_labels(self, key: Union[str, int]):
"""
Reepats labels for given key by the number of boots and yield it one by one.
:param key: key of station (either station name as string or the position in generator as integer)
:return: yields labels for length of boots
"""
_, label = self.orig_generator[key]
for _ in range(self.number_of_boots):
yield label
def get_orig_prediction(self, path: str, file_name: str, prediction_name: str = "CNN"):
"""
Repeats predictions from given file(_name) in path by the number of boots.
:param path: path to file
:param file_name: file name
:param prediction_name: name of the prediction to select from loaded file
:return: yields predictions for length of boots
"""
file = os.path.join(path, file_name) file = os.path.join(path, file_name)
data = xr.open_dataarray(file) data = xr.open_dataarray(file)
for _ in range(self.boots): for _ in range(self.number_of_boots):
yield data.sel(type=prediction_name).squeeze() yield data.sel(type=prediction_name).squeeze()
def load_boot_data(self, station): def load_shuffled_data(self, station: str, variables: List[str]) -> xr.DataArray:
"""
Load shuffled data from bootstrap path. Data is stored as
'<station>_<var1>_<var2>_..._hist<histsize>_nboots<nboots>_shuffled.nc', e.g.
'DEBW107_cloudcover_no_no2_temp_u_v_hist13_nboots20_shuffled.nc'
:param station:
:param variables:
:return: shuffled data as xarray
"""
files = os.listdir(self.bootstrap_path) files = os.listdir(self.bootstrap_path)
regex = re.compile(rf"{station}_\w*\.nc") regex = self.create_file_regex(station, variables)
file_name = os.path.join(self.bootstrap_path, list(filter(regex.search, files))[0]) file_name = os.path.join(self.bootstrap_path, list(filter(regex.search, files))[0])
shuffled_data = xr.open_dataarray(file_name, chunks=100) shuffled_data = xr.open_dataarray(file_name, chunks=100)
return shuffled_data return shuffled_data
@staticmethod
def create_file_regex(station, variables):
var_regex = "".join([rf'(_\w+)*_{v}(_\w+)*' for v in sorted(variables)])
regex = re.compile(rf"{station}{var_regex}_shuffled\.nc")
return regex
class BootStraps: class BootStraps:
...@@ -93,7 +114,7 @@ class BootStraps: ...@@ -93,7 +114,7 @@ class BootStraps:
self.bootstrap_path = bootstrap_path self.bootstrap_path = bootstrap_path
self.chunks = self.get_chunk_size() self.chunks = self.get_chunk_size()
self.create_shuffled_data() self.create_shuffled_data()
self._boot_strap_generator = BootStrapGenerator(self.data, self.number_bootstraps, self.chunks, self.bootstrap_path) self._boot_strap_generator = BootStrapGenerator(self.data, self.number_bootstraps, self.bootstrap_path)
def get_boot_strap_meta(self): def get_boot_strap_meta(self):
return self._boot_strap_generator.get_bootstrap_meta() return self._boot_strap_generator.get_bootstrap_meta()
...@@ -135,7 +156,7 @@ class BootStraps: ...@@ -135,7 +156,7 @@ class BootStraps:
randomly selected variables. If there is a suitable local file for requested window size and number of randomly selected variables. If there is a suitable local file for requested window size and number of
bootstraps, no additional file will be created inside this function. bootstraps, no additional file will be created inside this function.
""" """
logging.info("create shuffled bootstrap data") logging.info("create / check shuffled bootstrap data")
variables_str = '_'.join(sorted(self.data.variables)) variables_str = '_'.join(sorted(self.data.variables))
window = self.data.window_history_size window = self.data.window_history_size
for station in self.data.stations: for station in self.data.stations:
......
from src.data_handling.bootstraps import BootStraps from src.data_handling.bootstraps import BootStraps, BootStrapGenerator
from src.data_handling.data_generator import DataGenerator
import pytest import pytest
import os import os
import numpy as np import numpy as np
import xarray as xr
class TestBootstraps: class TestBootstraps:
...@@ -62,3 +64,56 @@ class TestBootstraps: ...@@ -62,3 +64,56 @@ class TestBootstraps:
def test_create_shuffled_data(self): def test_create_shuffled_data(self):
pass pass
class TestBootstrapGenerator:
@pytest.fixture
def orig_generator(self):
return DataGenerator(os.path.join(os.path.dirname(__file__), 'data'), 'AIRBASE', ['DEBW107', 'DEBW013'],
['o3', 'temp'], 'datetime', 'variables', 'o3', start=2010, end=2014)
@pytest.fixture
def boot_gen(self, orig_generator):
path = os.path.join(os.path.dirname(__file__), 'data')
dummy_content = xr.DataArray([1, 2, 3], dims="dummy")
dummy_content.to_netcdf(os.path.join(path, "DEBW107_o3_temp_shuffled.nc"))
dummy_content.to_netcdf(os.path.join(path, "DEBW013_o3_temp_shuffled.nc"))
return BootStrapGenerator(orig_generator, 20, path)
def test_init(self, orig_generator):
gen = BootStrapGenerator(orig_generator, 20, os.path.join(os.path.dirname(__file__), 'data'))
assert gen.stations == ["DEBW107", "DEBW013"]
assert gen.variables == ["o3", "temp"]
assert gen.number_of_boots == 20
assert gen.bootstrap_path == os.path.join(os.path.dirname(__file__), 'data')
def test_len(self, boot_gen):
assert len(boot_gen) == 80
def test_get_generator(self, boot_gen):
pass
def test_get_bootstrap_meta(self, boot_gen):
pass
def test_get_labels(self, boot_gen):
pass
def test_get_orig_prediction(self, boot_gen):
pass
def test_load_shuffled_data(self, boot_gen):
shuffled_data = boot_gen.load_shuffled_data("DEBW107", ["o3", "temp"])
assert isinstance(shuffled_data, xr.DataArray)
assert all(shuffled_data.compute().values == [1, 2, 3])
def test_create_file_regex(self, boot_gen):
regex = boot_gen.create_file_regex("DEBW108", ["o3", "temp", "h2o"])
test_list = ["DEBW108_o3_test23_test_shuffled.nc",
"DEBW107_o3_test23_test_shuffled.nc",
"DEBW108_o3_test23_test.nc",
"DEBW108_h2o_o3_temp_test_shuffled.nc",
"DEBW108_h2o_hum_latent_o3_temp_u_v_test23_test_shuffled.nc"]
assert list(filter(regex.search, test_list)) == ["DEBW108_h2o_o3_temp_test_shuffled.nc",
"DEBW108_h2o_hum_latent_o3_temp_u_v_test23_test_shuffled.nc"]
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment