Skip to content
Snippets Groups Projects
Commit 811bd7d6 authored by lukas leufen's avatar lukas leufen
Browse files

somre renaming and first tests for bootstrap generator

parent 46b3b49f
Branches
Tags
3 merge requests!90WIP: new release update,!89Resolve "release branch / CI on gpu",!61Resolve "REFAC: clean-up bootstrap workflow"
Pipeline #31451 passed
......@@ -10,30 +10,20 @@ import xarray as xr
import os
import re
from src import helpers
from typing import List
from typing import List, Union
class BootStrapGenerator:
def __init__(self, orig_generator, boots, chunksize, bootstrap_path):
def __init__(self, orig_generator, number_of_boots, bootstrap_path):
self.orig_generator: DataGenerator = orig_generator
self.stations = self.orig_generator.stations
self.variables = self.orig_generator.variables
self.boots = boots
self.chunksize = chunksize
self.number_of_boots = number_of_boots
self.bootstrap_path = bootstrap_path
self._iterator = 0
def __len__(self):
"""
display the number of stations
"""
return len(self.orig_generator)*self.boots*len(self.variables)
def get_labels(self, key):
_, label = self.orig_generator[key]
for _ in range(self.boots):
yield label
return len(self.orig_generator) * self.number_of_boots * len(self.variables)
def get_generator(self):
"""
......@@ -46,10 +36,10 @@ class BootStrapGenerator:
station = self.orig_generator.get_station_key(i)
logging.info(f"station: {station}")
hist, label = data
shuffled_data = self.load_boot_data(station)
shuffled_data = self.load_shuffled_data(station, self.variables)
for var in self.variables:
logging.info(f" var: {var}")
for boot in range(self.boots):
logging.debug(f" var: {var}")
for boot in range(self.number_of_boots):
logging.debug(f"boot: {boot}")
boot_hist = hist.sel(variables=helpers.list_pop(self.variables, var))
shuffled_var = shuffled_data.sel(variables=var, boots=boot).expand_dims("variables").drop("boots").transpose("datetime", "window", "Stations", "variables")
......@@ -67,23 +57,54 @@ class BootStrapGenerator:
for station in self.stations:
label = self.orig_generator.get_data_generator(station).get_transposed_label()
for var in self.variables:
for boot in range(self.boots):
for boot in range(self.number_of_boots):
bootstrap_meta.extend([[var, station]] * len(label))
return bootstrap_meta
def get_orig_prediction(self, path, file_name, prediction_name="CNN"):
def get_labels(self, key: Union[str, int]):
"""
Reepats labels for given key by the number of boots and yield it one by one.
:param key: key of station (either station name as string or the position in generator as integer)
:return: yields labels for length of boots
"""
_, label = self.orig_generator[key]
for _ in range(self.number_of_boots):
yield label
def get_orig_prediction(self, path: str, file_name: str, prediction_name: str = "CNN"):
"""
Repeats predictions from given file(_name) in path by the number of boots.
:param path: path to file
:param file_name: file name
:param prediction_name: name of the prediction to select from loaded file
:return: yields predictions for length of boots
"""
file = os.path.join(path, file_name)
data = xr.open_dataarray(file)
for _ in range(self.boots):
for _ in range(self.number_of_boots):
yield data.sel(type=prediction_name).squeeze()
def load_boot_data(self, station):
def load_shuffled_data(self, station: str, variables: List[str]) -> xr.DataArray:
"""
Load shuffled data from bootstrap path. Data is stored as
'<station>_<var1>_<var2>_..._hist<histsize>_nboots<nboots>_shuffled.nc', e.g.
'DEBW107_cloudcover_no_no2_temp_u_v_hist13_nboots20_shuffled.nc'
:param station:
:param variables:
:return: shuffled data as xarray
"""
files = os.listdir(self.bootstrap_path)
regex = re.compile(rf"{station}_\w*\.nc")
regex = self.create_file_regex(station, variables)
file_name = os.path.join(self.bootstrap_path, list(filter(regex.search, files))[0])
shuffled_data = xr.open_dataarray(file_name, chunks=100)
return shuffled_data
@staticmethod
def create_file_regex(station, variables):
var_regex = "".join([rf'(_\w+)*_{v}(_\w+)*' for v in sorted(variables)])
regex = re.compile(rf"{station}{var_regex}_shuffled\.nc")
return regex
class BootStraps:
......@@ -93,7 +114,7 @@ class BootStraps:
self.bootstrap_path = bootstrap_path
self.chunks = self.get_chunk_size()
self.create_shuffled_data()
self._boot_strap_generator = BootStrapGenerator(self.data, self.number_bootstraps, self.chunks, self.bootstrap_path)
self._boot_strap_generator = BootStrapGenerator(self.data, self.number_bootstraps, self.bootstrap_path)
def get_boot_strap_meta(self):
return self._boot_strap_generator.get_bootstrap_meta()
......@@ -135,7 +156,7 @@ class BootStraps:
randomly selected variables. If there is a suitable local file for requested window size and number of
bootstraps, no additional file will be created inside this function.
"""
logging.info("create shuffled bootstrap data")
logging.info("create / check shuffled bootstrap data")
variables_str = '_'.join(sorted(self.data.variables))
window = self.data.window_history_size
for station in self.data.stations:
......
from src.data_handling.bootstraps import BootStraps
from src.data_handling.bootstraps import BootStraps, BootStrapGenerator
from src.data_handling.data_generator import DataGenerator
import pytest
import os
import numpy as np
import xarray as xr
class TestBootstraps:
......@@ -61,4 +63,57 @@ class TestBootstraps:
assert set(np.unique(res)).issubset({1, 2, 3})
def test_create_shuffled_data(self):
pass
\ No newline at end of file
pass
class TestBootstrapGenerator:
@pytest.fixture
def orig_generator(self):
return DataGenerator(os.path.join(os.path.dirname(__file__), 'data'), 'AIRBASE', ['DEBW107', 'DEBW013'],
['o3', 'temp'], 'datetime', 'variables', 'o3', start=2010, end=2014)
@pytest.fixture
def boot_gen(self, orig_generator):
path = os.path.join(os.path.dirname(__file__), 'data')
dummy_content = xr.DataArray([1, 2, 3], dims="dummy")
dummy_content.to_netcdf(os.path.join(path, "DEBW107_o3_temp_shuffled.nc"))
dummy_content.to_netcdf(os.path.join(path, "DEBW013_o3_temp_shuffled.nc"))
return BootStrapGenerator(orig_generator, 20, path)
def test_init(self, orig_generator):
gen = BootStrapGenerator(orig_generator, 20, os.path.join(os.path.dirname(__file__), 'data'))
assert gen.stations == ["DEBW107", "DEBW013"]
assert gen.variables == ["o3", "temp"]
assert gen.number_of_boots == 20
assert gen.bootstrap_path == os.path.join(os.path.dirname(__file__), 'data')
def test_len(self, boot_gen):
assert len(boot_gen) == 80
def test_get_generator(self, boot_gen):
pass
def test_get_bootstrap_meta(self, boot_gen):
pass
def test_get_labels(self, boot_gen):
pass
def test_get_orig_prediction(self, boot_gen):
pass
def test_load_shuffled_data(self, boot_gen):
shuffled_data = boot_gen.load_shuffled_data("DEBW107", ["o3", "temp"])
assert isinstance(shuffled_data, xr.DataArray)
assert all(shuffled_data.compute().values == [1, 2, 3])
def test_create_file_regex(self, boot_gen):
regex = boot_gen.create_file_regex("DEBW108", ["o3", "temp", "h2o"])
test_list = ["DEBW108_o3_test23_test_shuffled.nc",
"DEBW107_o3_test23_test_shuffled.nc",
"DEBW108_o3_test23_test.nc",
"DEBW108_h2o_o3_temp_test_shuffled.nc",
"DEBW108_h2o_hum_latent_o3_temp_u_v_test23_test_shuffled.nc"]
assert list(filter(regex.search, test_list)) == ["DEBW108_h2o_o3_temp_test_shuffled.nc",
"DEBW108_h2o_hum_latent_o3_temp_u_v_test23_test_shuffled.nc"]
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment