From ac1e2562f27f59962b450da446e3cc97c583ed0b Mon Sep 17 00:00:00 2001 From: leufen1 <l.leufen@fz-juelich.de> Date: Mon, 18 Oct 2021 10:21:45 +0200 Subject: [PATCH] renamed feature importance bootstrap variables and methods --- mlair/configuration/defaults.py | 11 ++- mlair/data_handler/__init__.py | 2 +- .../{bootstraps.py => input_bootstraps.py} | 6 +- mlair/run_modules/experiment_setup.py | 55 ++++++----- mlair/run_modules/post_processing.py | 95 ++++++++++--------- test/test_data_handler/old_t_bootstraps.py | 14 +-- 6 files changed, 98 insertions(+), 85 deletions(-) rename mlair/data_handler/{bootstraps.py => input_bootstraps.py} (98%) diff --git a/mlair/configuration/defaults.py b/mlair/configuration/defaults.py index 9dc252b8..255f1227 100644 --- a/mlair/configuration/defaults.py +++ b/mlair/configuration/defaults.py @@ -44,14 +44,15 @@ DEFAULT_TEST_END = "2017-12-31" DEFAULT_TEST_MIN_LENGTH = 90 DEFAULT_TRAIN_VAL_MIN_LENGTH = 180 DEFAULT_USE_ALL_STATIONS_ON_ALL_DATA_SETS = True +DEFAULT_DO_UNCERTAINTY_ESTIMATE = True DEFAULT_UNCERTAINTY_ESTIMATE_BLOCK_LENGTH = "1m" DEFAULT_UNCERTAINTY_ESTIMATE_EVALUATE_COMPETITORS = True DEFAULT_UNCERTAINTY_ESTIMATE_N_BOOTS = 1000 -DEFAULT_EVALUATE_BOOTSTRAPS = True -DEFAULT_CREATE_NEW_BOOTSTRAPS = False -DEFAULT_NUMBER_OF_BOOTSTRAPS = 20 -DEFAULT_BOOTSTRAP_TYPE = "singleinput" -DEFAULT_BOOTSTRAP_METHOD = "shuffle" +DEFAULT_EVALUATE_FEATURE_IMPORTANCE = True +DEFAULT_FEATURE_IMPORTANCE_CREATE_NEW_BOOTSTRAPS = False +DEFAULT_FEATURE_IMPORTANCE_N_BOOTS = 20 +DEFAULT_FEATURE_IMPORTANCE_BOOTSTRAP_TYPE = "singleinput" +DEFAULT_FEATURE_IMPORTANCE_BOOTSTRAP_METHOD = "shuffle" DEFAULT_PLOT_LIST = ["PlotMonthlySummary", "PlotStationMap", "PlotClimatologicalSkillScore", "PlotTimeSeries", "PlotCompetitiveSkillScore", "PlotBootstrapSkillScore", "PlotConditionalQuantiles", "PlotAvailability", "PlotAvailabilityHistogram", "PlotDataHistogram", "PlotPeriodogram"] diff --git a/mlair/data_handler/__init__.py b/mlair/data_handler/__init__.py index 495b6e7c..d1199778 100644 --- a/mlair/data_handler/__init__.py +++ b/mlair/data_handler/__init__.py @@ -9,7 +9,7 @@ __author__ = 'Lukas Leufen, Felix Kleinert' __date__ = '2020-04-17' -from .bootstraps import BootStraps +from .input_bootstraps import Bootstraps from .iterator import KerasIterator, DataCollection from .default_data_handler import DefaultDataHandler from .abstract_data_handler import AbstractDataHandler diff --git a/mlair/data_handler/bootstraps.py b/mlair/data_handler/input_bootstraps.py similarity index 98% rename from mlair/data_handler/bootstraps.py rename to mlair/data_handler/input_bootstraps.py index e0388148..ab4c71f1 100644 --- a/mlair/data_handler/bootstraps.py +++ b/mlair/data_handler/input_bootstraps.py @@ -28,8 +28,8 @@ class BootstrapIterator(Iterator): _position: int = None - def __init__(self, data: "BootStraps", method): - assert isinstance(data, BootStraps) + def __init__(self, data: "Bootstraps", method): + assert isinstance(data, Bootstraps) self._data = data self._dimension = data.bootstrap_dimension self.boot_dim = "boots" @@ -184,7 +184,7 @@ class MeanBootstraps: return np.ones_like(data) * self._mean -class BootStraps(Iterable): +class Bootstraps(Iterable): """ Main class to perform bootstrap operations. 
diff --git a/mlair/run_modules/experiment_setup.py b/mlair/run_modules/experiment_setup.py index aa7a2bc1..68901d33 100644 --- a/mlair/run_modules/experiment_setup.py +++ b/mlair/run_modules/experiment_setup.py @@ -18,12 +18,12 @@ from mlair.configuration.defaults import DEFAULT_STATIONS, DEFAULT_VAR_ALL_DICT, DEFAULT_WINDOW_DIM, DEFAULT_DIMENSIONS, DEFAULT_TIME_DIM, DEFAULT_INTERPOLATION_METHOD, DEFAULT_INTERPOLATION_LIMIT, \ DEFAULT_TRAIN_START, DEFAULT_TRAIN_END, DEFAULT_TRAIN_MIN_LENGTH, DEFAULT_VAL_START, DEFAULT_VAL_END, \ DEFAULT_VAL_MIN_LENGTH, DEFAULT_TEST_START, DEFAULT_TEST_END, DEFAULT_TEST_MIN_LENGTH, DEFAULT_TRAIN_VAL_MIN_LENGTH, \ - DEFAULT_USE_ALL_STATIONS_ON_ALL_DATA_SETS, DEFAULT_EVALUATE_BOOTSTRAPS, DEFAULT_CREATE_NEW_BOOTSTRAPS, \ - DEFAULT_NUMBER_OF_BOOTSTRAPS, DEFAULT_PLOT_LIST, DEFAULT_SAMPLING, DEFAULT_DATA_ORIGIN, DEFAULT_ITER_DIM, \ + DEFAULT_USE_ALL_STATIONS_ON_ALL_DATA_SETS, DEFAULT_EVALUATE_FEATURE_IMPORTANCE, DEFAULT_FEATURE_IMPORTANCE_CREATE_NEW_BOOTSTRAPS, \ + DEFAULT_FEATURE_IMPORTANCE_N_BOOTS, DEFAULT_PLOT_LIST, DEFAULT_SAMPLING, DEFAULT_DATA_ORIGIN, DEFAULT_ITER_DIM, \ DEFAULT_USE_MULTIPROCESSING, DEFAULT_USE_MULTIPROCESSING_ON_DEBUG, DEFAULT_MAX_NUMBER_MULTIPROCESSING, \ - DEFAULT_BOOTSTRAP_TYPE, DEFAULT_BOOTSTRAP_METHOD, DEFAULT_OVERWRITE_LAZY_DATA, \ + DEFAULT_FEATURE_IMPORTANCE_BOOTSTRAP_TYPE, DEFAULT_FEATURE_IMPORTANCE_BOOTSTRAP_METHOD, DEFAULT_OVERWRITE_LAZY_DATA, \ DEFAULT_UNCERTAINTY_ESTIMATE_BLOCK_LENGTH, DEFAULT_UNCERTAINTY_ESTIMATE_EVALUATE_COMPETITORS, \ - DEFAULT_UNCERTAINTY_ESTIMATE_N_BOOTS + DEFAULT_UNCERTAINTY_ESTIMATE_N_BOOTS, DEFAULT_DO_UNCERTAINTY_ESTIMATE from mlair.data_handler import DefaultDataHandler from mlair.run_modules.run_environment import RunEnvironment from mlair.model_modules.fully_connected_networks import FCN_64_32_16 as VanillaModel @@ -214,16 +214,17 @@ class ExperimentSetup(RunEnvironment): sampling: str = None, create_new_model=None, bootstrap_path=None, permute_data_on_training=None, transformation=None, train_min_length=None, val_min_length=None, test_min_length=None, extreme_values: list = None, - extremes_on_right_tail_only: bool = None, evaluate_bootstraps=None, plot_list=None, - number_of_bootstraps=None, create_new_bootstraps=None, bootstrap_method=None, bootstrap_type=None, + extremes_on_right_tail_only: bool = None, evaluate_feature_importance: bool = None, plot_list=None, + feature_importance_n_boots: int = None, feature_importance_create_new_bootstraps: bool = None, + feature_importance_bootstrap_method=None, feature_importance_bootstrap_type=None, data_path: str = None, batch_path: str = None, login_nodes=None, hpc_hosts=None, model=None, batch_size=None, epochs=None, data_handler=None, data_origin: Dict = None, competitors: list = None, competitor_path: str = None, use_multiprocessing: bool = None, use_multiprocessing_on_debug: bool = None, max_number_multiprocessing: int = None, start_script: Union[Callable, str] = None, overwrite_lazy_data: bool = None, uncertainty_estimate_block_length: str = None, - uncertainty_estimate_evaluate_competitors: bool = None, uncertainty_estimate_n_boots: int= None, - **kwargs): + uncertainty_estimate_evaluate_competitors: bool = None, uncertainty_estimate_n_boots: int = None, + do_uncertainty_estimate: bool = None, **kwargs): # create run framework super().__init__() @@ -353,22 +354,28 @@ class ExperimentSetup(RunEnvironment): default=DEFAULT_USE_ALL_STATIONS_ON_ALL_DATA_SETS) # set post-processing instructions - self._set_param("uncertainty_estimate_block_length", 
uncertainty_estimate_block_length, - default=DEFAULT_UNCERTAINTY_ESTIMATE_BLOCK_LENGTH) - self._set_param("uncertainty_estimate_evaluate_competitors", uncertainty_estimate_evaluate_competitors, - default=DEFAULT_UNCERTAINTY_ESTIMATE_EVALUATE_COMPETITORS) - self._set_param("uncertainty_estimate_n_boots", uncertainty_estimate_n_boots, - default=DEFAULT_UNCERTAINTY_ESTIMATE_N_BOOTS) - - self._set_param("evaluate_bootstraps", evaluate_bootstraps, default=DEFAULT_EVALUATE_BOOTSTRAPS, - scope="general.postprocessing") - create_new_bootstraps = max([self.data_store.get("train_model", "general"), - create_new_bootstraps or DEFAULT_CREATE_NEW_BOOTSTRAPS]) - self._set_param("create_new_bootstraps", create_new_bootstraps, scope="general.postprocessing") - self._set_param("number_of_bootstraps", number_of_bootstraps, default=DEFAULT_NUMBER_OF_BOOTSTRAPS, - scope="general.postprocessing") - self._set_param("bootstrap_method", bootstrap_method, default=DEFAULT_BOOTSTRAP_METHOD) - self._set_param("bootstrap_type", bootstrap_type, default=DEFAULT_BOOTSTRAP_TYPE) + self._set_param("do_uncertainty_estimate", do_uncertainty_estimate, + default=DEFAULT_DO_UNCERTAINTY_ESTIMATE, scope="general.postprocessing") + self._set_param("block_length", uncertainty_estimate_block_length, + default=DEFAULT_UNCERTAINTY_ESTIMATE_BLOCK_LENGTH, scope="uncertainty_estimate") + self._set_param("evaluate_competitors", uncertainty_estimate_evaluate_competitors, + default=DEFAULT_UNCERTAINTY_ESTIMATE_EVALUATE_COMPETITORS, scope="uncertainty_estimate") + self._set_param("n_boots", uncertainty_estimate_n_boots, + default=DEFAULT_UNCERTAINTY_ESTIMATE_N_BOOTS, scope="uncertainty_estimate") + + self._set_param("evaluate_feature_importance", evaluate_feature_importance, + default=DEFAULT_EVALUATE_FEATURE_IMPORTANCE, scope="general.postprocessing") + feature_importance_create_new_bootstraps = max([self.data_store.get("train_model", "general"), + feature_importance_create_new_bootstraps or + DEFAULT_FEATURE_IMPORTANCE_CREATE_NEW_BOOTSTRAPS]) + self._set_param("create_new_bootstraps", feature_importance_create_new_bootstraps, scope="feature_importance") + self._set_param("n_boots", feature_importance_n_boots, default=DEFAULT_FEATURE_IMPORTANCE_N_BOOTS, + scope="feature_importance") + self._set_param("bootstrap_method", feature_importance_bootstrap_method, + default=DEFAULT_FEATURE_IMPORTANCE_BOOTSTRAP_METHOD, scope="feature_importance") + self._set_param("bootstrap_type", feature_importance_bootstrap_type, + default=DEFAULT_FEATURE_IMPORTANCE_BOOTSTRAP_TYPE, scope="feature_importance") + self._set_param("plot_list", plot_list, default=DEFAULT_PLOT_LIST, scope="general.postprocessing") self._set_param("neighbors", ["DEBW030"]) # TODO: just for testing diff --git a/mlair/run_modules/post_processing.py b/mlair/run_modules/post_processing.py index e580a2b3..5210c12e 100644 --- a/mlair/run_modules/post_processing.py +++ b/mlair/run_modules/post_processing.py @@ -16,7 +16,7 @@ import pandas as pd import xarray as xr from mlair.configuration import path_config -from mlair.data_handler import BootStraps, KerasIterator +from mlair.data_handler import Bootstraps, KerasIterator from mlair.helpers.datastore import NameNotFoundInDataStore from mlair.helpers import TimeTracking, statistics, extract_value, remove_items, to_list, tables from mlair.model_modules.linear_model import OrdinaryLeastSquaredModel @@ -48,7 +48,7 @@ class PostProcessing(RunEnvironment): * `target_var` [.] * `sampling` [.] 
         * `output_shape` [model]
-        * `evaluate_bootstraps` [postprocessing] and if enabled:
+        * `evaluate_feature_importance` [postprocessing] and if enabled:

             * `create_new_bootstraps` [postprocessing]
             * `bootstrap_path` [postprocessing]

@@ -83,7 +83,7 @@ class PostProcessing(RunEnvironment):
         self._sampling = self.data_store.get("sampling")
         self.window_lead_time = extract_value(self.data_store.get("output_shape", "model"))
         self.skill_scores = None
-        self.bootstrap_skill_scores = None
+        self.feature_importance_skill_scores = None
         self.competitor_path = self.data_store.get("competitor_path")
         self.competitors = to_list(self.data_store.get_default("competitors", default=[]))
         self.forecast_indicator = "nn"
@@ -106,18 +106,19 @@ class PostProcessing(RunEnvironment):
             self.calculate_test_score()

         # sample uncertainty
-        self.estimate_sample_uncertainty()
-
-        # bootstraps
-        if self.data_store.get("evaluate_bootstraps", "postprocessing"):
-            with TimeTracking(name="calculate bootstraps"):
-                create_new_bootstraps = self.data_store.get("create_new_bootstraps", "postprocessing")
-                bootstrap_method = self.data_store.get("bootstrap_method", "postprocessing")
-                bootstrap_type = self.data_store.get("bootstrap_type", "postprocessing")
-                self.bootstrap_postprocessing(create_new_bootstraps, bootstrap_type=bootstrap_type,
-                                              bootstrap_method=bootstrap_method)
-            if self.bootstrap_skill_scores is not None:
-                self.report_bootstrap_results(self.bootstrap_skill_scores)
+        if self.data_store.get("do_uncertainty_estimate", "postprocessing"):
+            self.estimate_sample_uncertainty()
+
+        # feature importance bootstraps
+        if self.data_store.get("evaluate_feature_importance", "postprocessing"):
+            with TimeTracking(name="calculate feature importance using bootstraps"):
+                create_new_bootstraps = self.data_store.get("create_new_bootstraps", "feature_importance")
+                bootstrap_method = self.data_store.get("bootstrap_method", "feature_importance")
+                bootstrap_type = self.data_store.get("bootstrap_type", "feature_importance")
+                self.calculate_feature_importance(create_new_bootstraps, bootstrap_type=bootstrap_type,
+                                                  bootstrap_method=bootstrap_method)
+            if self.feature_importance_skill_scores is not None:
+                self.report_feature_importance_results(self.feature_importance_skill_scores)

         # skill scores and error metrics
         with TimeTracking(name="calculate skill scores"):
@@ -133,9 +134,9 @@
     def estimate_sample_uncertainty(self, separate_ahead=False):
         #todo: visualize
         #todo: write results on disk
-        n_boots = self.data_store.get_default("uncertainty_estimate_n_boots", default=1000)
-        block_length = self.data_store.get_default("uncertainty_estimate_block_length", default="1m")
-        evaluate_competitors = self.data_store.get_default("uncertainty_estimate_evaluate_competitors", default=True)
+        n_boots = self.data_store.get_default("n_boots", default=1000, scope="uncertainty_estimate")
+        block_length = self.data_store.get_default("block_length", default="1m", scope="uncertainty_estimate")
+        evaluate_competitors = self.data_store.get_default("evaluate_competitors", default=True, scope="uncertainty_estimate")
         block_mse = self.calculate_block_mse(evaluate_competitors=evaluate_competitors, separate_ahead=separate_ahead,
                                              block_length=block_length)
         res = statistics.create_n_bootstrap_realizations(block_mse, self.index_dim, self.model_type_dim,
@@ -222,8 +223,8 @@
                 continue
         return xr.concat(competing_predictions, self.model_type_dim) if len(competing_predictions) > 0 else None

-    def bootstrap_postprocessing(self, create_new_bootstraps: bool, _iter: int = 0, bootstrap_type="singleinput",
-                                 bootstrap_method="shuffle") -> None:
+    def calculate_feature_importance(self, create_new_bootstraps: bool, _iter: int = 0, bootstrap_type="singleinput",
+                                     bootstrap_method="shuffle") -> None:
         """
         Calculate skill scores of bootstrapped data.

@@ -237,26 +238,28 @@
             went wrong).
         """
         if _iter == 0:
-            self.bootstrap_skill_scores = {}
+            self.feature_importance_skill_scores = {}
         for boot_type in to_list(bootstrap_type):
-            self.bootstrap_skill_scores[boot_type] = {}
+            self.feature_importance_skill_scores[boot_type] = {}
             for boot_method in to_list(bootstrap_method):
                 try:
                     if create_new_bootstraps:
-                        self.create_bootstrap_forecast(bootstrap_type=boot_type, bootstrap_method=boot_method)
-                    boot_skill_score = self.calculate_bootstrap_skill_scores(bootstrap_type=boot_type,
-                                                                             bootstrap_method=boot_method)
-                    self.bootstrap_skill_scores[boot_type][boot_method] = boot_skill_score
+                        self.create_feature_importance_bootstrap_forecast(bootstrap_type=boot_type,
+                                                                          bootstrap_method=boot_method)
+                    boot_skill_score = self.calculate_feature_importance_skill_scores(bootstrap_type=boot_type,
+                                                                                      bootstrap_method=boot_method)
+                    self.feature_importance_skill_scores[boot_type][boot_method] = boot_skill_score
                 except FileNotFoundError:
                     if _iter != 0:
-                        raise RuntimeError(f"bootstrap_postprocessing ({boot_type}, {boot_type}) was called for the 2nd"
-                                           f" time. This means, that something internally goes wrong. Please check for "
-                                           f"possible errors")
-                    logging.info(f"Could not load all files for bootstrapping ({boot_type}, {boot_type}), restart "
-                                 f"bootstrap postprocessing with create_new_bootstraps=True.")
-                    self.bootstrap_postprocessing(True, _iter=1, bootstrap_type=boot_type, bootstrap_method=boot_method)
+                        raise RuntimeError(f"calculate_feature_importance ({boot_type}, {boot_method}) was called "
+                                           f"for the 2nd time. This means that something went wrong internally. "
+                                           f"Please check for possible errors.")
+                    logging.info(f"Could not load all files for feature importance ({boot_type}, {boot_method}), "
+                                 f"restart calculate_feature_importance with create_new_bootstraps=True.")
+                    self.calculate_feature_importance(True, _iter=1, bootstrap_type=boot_type,
+                                                      bootstrap_method=boot_method)

-    def create_bootstrap_forecast(self, bootstrap_type, bootstrap_method) -> None:
+    def create_feature_importance_bootstrap_forecast(self, bootstrap_type, bootstrap_method) -> None:
         """
         Create bootstrapped predictions for all stations and variables.

@@ -267,11 +270,11 @@ class PostProcessing(RunEnvironment): with TimeTracking(name=f"{inspect.stack()[0].function} ({bootstrap_type}, {bootstrap_method})"): # extract all requirements from data store forecast_path = self.data_store.get("forecast_path") - number_of_bootstraps = self.data_store.get("number_of_bootstraps", "postprocessing") + number_of_bootstraps = self.data_store.get("n_boots", "feature_importance") dims = [self.index_dim, self.ahead_dim, self.model_type_dim] for station in self.test_data: X, Y = None, None - bootstraps = BootStraps(station, number_of_bootstraps, bootstrap_type=bootstrap_type, + bootstraps = Bootstraps(station, number_of_bootstraps, bootstrap_type=bootstrap_type, bootstrap_method=bootstrap_method) for boot in bootstraps: X, Y, (index, dimension) = boot @@ -295,7 +298,7 @@ class PostProcessing(RunEnvironment): labels = xr.DataArray(labels, coords=(*coords, [self.observation_indicator]), dims=dims) labels.to_netcdf(file_name) - def calculate_bootstrap_skill_scores(self, bootstrap_type, bootstrap_method) -> Dict[str, xr.DataArray]: + def calculate_feature_importance_skill_scores(self, bootstrap_type, bootstrap_method) -> Dict[str, xr.DataArray]: """ Calculate skill score of bootstrapped variables. @@ -308,10 +311,11 @@ class PostProcessing(RunEnvironment): with TimeTracking(name=f"{inspect.stack()[0].function} ({bootstrap_type}, {bootstrap_method})"): # extract all requirements from data store forecast_path = self.data_store.get("forecast_path") - number_of_bootstraps = self.data_store.get("number_of_bootstraps", "postprocessing") + number_of_bootstraps = self.data_store.get("n_boots", "feature_importance") forecast_file = f"forecasts_norm_%s_test.nc" + reference_name = "orig" - bootstraps = BootStraps(self.test_data[0], number_of_bootstraps, bootstrap_type=bootstrap_type, + bootstraps = Bootstraps(self.test_data[0], number_of_bootstraps, bootstrap_type=bootstrap_type, bootstrap_method=bootstrap_method) number_of_bootstraps = bootstraps.number_of_bootstraps bootstrap_iter = bootstraps.bootstraps() @@ -327,7 +331,7 @@ class PostProcessing(RunEnvironment): # get original forecasts orig = self.get_orig_prediction(forecast_path, forecast_file % str(station), number_of_bootstraps) orig = orig.reshape(shape) - coords = (range(shape[0]), range(1, shape[1] + 1), ["orig"]) + coords = (range(shape[0]), range(1, shape[1] + 1), [reference_name]) orig = xr.DataArray(orig, coords=coords, dims=[self.index_dim, self.ahead_dim, self.model_type_dim]) # calculate skill scores for each variable @@ -343,7 +347,8 @@ class PostProcessing(RunEnvironment): for ahead in range(1, self.window_lead_time + 1): data = boot_data.sel({self.ahead_dim: ahead}) boot_scores.append( - skill_scores.general_skill_score(data, forecast_name=boot_var, reference_name="orig")) + skill_scores.general_skill_score(data, forecast_name=boot_var, + reference_name=reference_name)) skill.loc[boot_var] = np.array(boot_scores) # collect all results in single dictionary @@ -424,8 +429,8 @@ class PostProcessing(RunEnvironment): f"\n{sys.exc_info()[0]}\n{sys.exc_info()[1]}\n{sys.exc_info()[2]}") try: - if (self.bootstrap_skill_scores is not None) and ("PlotBootstrapSkillScore" in plot_list): - for boot_type, boot_data in self.bootstrap_skill_scores.items(): + if (self.feature_importance_skill_scores is not None) and ("PlotBootstrapSkillScore" in plot_list): + for boot_type, boot_data in self.feature_importance_skill_scores.items(): for boot_method, boot_skill_score in boot_data.items(): try: 
PlotBootstrapSkillScore(boot_skill_score, plot_folder=self.plot_path, @@ -910,8 +915,8 @@ class PostProcessing(RunEnvironment): avg_error[error_metric] = new_val return avg_error - def report_bootstrap_results(self, results): - """Create a csv file containing all results from bootstrapping.""" + def report_feature_importance_results(self, results): + """Create a csv file containing all results from feature importance.""" report_path = os.path.join(self.data_store.get("experiment_path"), "latex_report") path_config.check_path_and_create(report_path) res = [[self.model_type_dim, "method", "station", self.boot_var_dim, self.ahead_dim, "vals"]] @@ -924,7 +929,7 @@ class PostProcessing(RunEnvironment): float(vals.sel({self.boot_var_dim: boot_var, self.ahead_dim: ahead}))]) col_names = res.pop(0) df = pd.DataFrame(res, columns=col_names) - file_name = "bootstrap_skill_score_report_raw.csv" + file_name = "feature_importance_skill_score_report_raw.csv" df.to_csv(os.path.join(report_path, file_name), sep=";") def report_error_metrics(self, errors): diff --git a/test/test_data_handler/old_t_bootstraps.py b/test/test_data_handler/old_t_bootstraps.py index 21c18c6c..e21af9f6 100644 --- a/test/test_data_handler/old_t_bootstraps.py +++ b/test/test_data_handler/old_t_bootstraps.py @@ -7,7 +7,7 @@ import numpy as np import pytest import xarray as xr -from mlair.data_handler.bootstraps import BootStraps +from mlair.data_handler.input_bootstraps import Bootstraps from src.data_handler import DataPrepJoin @@ -171,22 +171,22 @@ class TestBootStraps: @pytest.fixture def bootstrap(self, orig_generator, data_path): - return BootStraps(orig_generator, data_path, 20) + return Bootstraps(orig_generator, data_path, 20) @pytest.fixture @mock.patch("mlair.data_handling.bootstraps.CreateShuffledData", return_value=None) def bootstrap_no_shuffling(self, mock_create_shuffle_data, orig_generator, data_path): shutil.rmtree(data_path) - return BootStraps(orig_generator, data_path, 20) + return Bootstraps(orig_generator, data_path, 20) def test_init_no_shuffling(self, bootstrap_no_shuffling, data_path): - assert isinstance(bootstrap_no_shuffling, BootStraps) + assert isinstance(bootstrap_no_shuffling, Bootstraps) assert bootstrap_no_shuffling.number_of_bootstraps == 20 assert bootstrap_no_shuffling.bootstrap_path == data_path def test_init_with_shuffling(self, orig_generator, data_path, caplog): caplog.set_level(logging.INFO) - BootStraps(orig_generator, data_path, 20) + Bootstraps(orig_generator, data_path, 20) assert caplog.record_tuples[0] == ('root', logging.INFO, "create / check shuffled bootstrap data") def test_stations(self, bootstrap_no_shuffling, orig_generator): @@ -213,9 +213,9 @@ class TestBootStraps: @mock.patch("mlair.data_handling.data_generator.DataGenerator._load_pickle_data", side_effect=FileNotFoundError) def test_get_generator_different_generator(self, mock_load_pickle, data_path, orig_generator): - BootStraps(orig_generator, data_path, 20) # to create + Bootstraps(orig_generator, data_path, 20) # to create orig_generator.window_history_size = 4 - bootstrap = BootStraps(orig_generator, data_path, 20) + bootstrap = Bootstraps(orig_generator, data_path, 20) station = bootstrap.stations[0] var = bootstrap.variables[0] var_others = bootstrap.variables[1:] -- GitLab
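Usage note (not part of the commit above): a minimal sketch of how the renamed post-processing options would be passed to an MLAir run after this change. The keyword names (do_uncertainty_estimate, uncertainty_estimate_*, evaluate_feature_importance, feature_importance_*) are taken directly from the ExperimentSetup signature in this diff; the DefaultWorkflow entry point and the example station IDs are assumptions about the surrounding MLAir setup, not something this patch introduces.

# Hypothetical run script (Python, sketch only): exercises the renamed parameters.
from mlair.workflows import DefaultWorkflow  # assumed standard MLAir workflow entry point

workflow = DefaultWorkflow(
    stations=["DEBW107", "DEBW013"],  # example stations, adjust to your data
    # sample uncertainty estimate, stored under scope "uncertainty_estimate"
    do_uncertainty_estimate=True,
    uncertainty_estimate_block_length="1m",
    uncertainty_estimate_evaluate_competitors=True,
    uncertainty_estimate_n_boots=1000,
    # feature importance bootstraps, stored under scope "feature_importance"
    evaluate_feature_importance=True,
    feature_importance_create_new_bootstraps=False,
    feature_importance_n_boots=20,
    feature_importance_bootstrap_type="singleinput",
    feature_importance_bootstrap_method="shuffle",
)
workflow.run()

Because the values are now stored under the dedicated scopes "uncertainty_estimate" and "feature_importance", PostProcessing retrieves them with the short keys n_boots, block_length, create_new_bootstraps, bootstrap_type and bootstrap_method instead of the old prefixed parameter names.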