From 4800b9e720264989c4456fa21a24c6a13f998e02 Mon Sep 17 00:00:00 2001 From: lukas leufen <l.leufen@fz-juelich.de> Date: Wed, 26 Aug 2020 11:57:48 +0200 Subject: [PATCH] updated naming also for all variables in the run modules --- mlair/run_modules/experiment_setup.py | 4 +-- mlair/run_modules/pre_processing.py | 26 ++++++++++---------- mlair/run_script.py | 2 +- test/test_run_modules/test_pre_processing.py | 4 +-- 4 files changed, 18 insertions(+), 18 deletions(-) diff --git a/mlair/run_modules/experiment_setup.py b/mlair/run_modules/experiment_setup.py index d7ecbac5..51e710c2 100644 --- a/mlair/run_modules/experiment_setup.py +++ b/mlair/run_modules/experiment_setup.py @@ -221,7 +221,7 @@ class ExperimentSetup(RunEnvironment): train_min_length=None, val_min_length=None, test_min_length=None, extreme_values: list = None, extremes_on_right_tail_only: bool = None, evaluate_bootstraps=None, plot_list=None, number_of_bootstraps=None, create_new_bootstraps=None, data_path: str = None, batch_path: str = None, login_nodes=None, - hpc_hosts=None, model=None, batch_size=None, epochs=None, data_preparation=None, **kwargs): + hpc_hosts=None, model=None, batch_size=None, epochs=None, data_handler=None, **kwargs): # create run framework super().__init__() @@ -290,7 +290,7 @@ class ExperimentSetup(RunEnvironment): self._set_param("sampling", sampling) self._set_param("transformation", transformation, default=DEFAULT_TRANSFORMATION) self._set_param("transformation", None, scope="preprocessing") - self._set_param("data_preparation", data_preparation, default=DefaultDataHandler) + self._set_param("data_handler", data_handler, default=DefaultDataHandler) # target self._set_param("target_var", target_var, default=DEFAULT_TARGET_VAR) diff --git a/mlair/run_modules/pre_processing.py b/mlair/run_modules/pre_processing.py index b4185df2..ed972896 100644 --- a/mlair/run_modules/pre_processing.py +++ b/mlair/run_modules/pre_processing.py @@ -10,7 +10,7 @@ from typing import Tuple import numpy as np import pandas as pd -from mlair.data_handler import DataCollection +from mlair.data_handler import DataCollection, AbstractDataHandler from mlair.helpers import TimeTracking from mlair.configuration import path_config from mlair.helpers.join import EmptyQueryResult @@ -55,8 +55,8 @@ class PreProcessing(RunEnvironment): def _run(self): stations = self.data_store.get("stations") - data_preparation = self.data_store.get("data_preparation") - _, valid_stations = self.validate_station(data_preparation, stations, "preprocessing", overwrite_local_data=True) + data_handler = self.data_store.get("data_handler") + _, valid_stations = self.validate_station(data_handler, stations, "preprocessing", overwrite_local_data=True) if len(valid_stations) == 0: raise ValueError("Couldn't find any valid data according to given parameters. Abort experiment run.") self.data_store.set("stations", valid_stations) @@ -187,12 +187,12 @@ class PreProcessing(RunEnvironment): set_stations = stations[index_list] logging.debug(f"{set_name.capitalize()} stations (len={len(set_stations)}): {set_stations}") # create set data_collection and store - data_preparation = self.data_store.get("data_preparation") - collection, valid_stations = self.validate_station(data_preparation, set_stations, set_name) + data_handler = self.data_store.get("data_handler") + collection, valid_stations = self.validate_station(data_handler, set_stations, set_name) self.data_store.set("stations", valid_stations, scope=set_name) self.data_store.set("data_collection", collection, scope=set_name) - def validate_station(self, data_preparation, set_stations, set_name=None, overwrite_local_data=False): + def validate_station(self, data_handler: AbstractDataHandler, set_stations, set_name=None, overwrite_local_data=False): """ Check if all given stations in `all_stations` are valid. @@ -212,14 +212,14 @@ class PreProcessing(RunEnvironment): logging.info(f"check valid stations started{' (%s)' % (set_name if set_name is not None else 'all')}") # calculate transformation using train data if set_name == "train": - self.transformation(data_preparation, set_stations) + self.transformation(data_handler, set_stations) # start station check collection = DataCollection() valid_stations = [] - kwargs = self.data_store.create_args_dict(data_preparation.requirements(), scope=set_name) + kwargs = self.data_store.create_args_dict(data_handler.requirements(), scope=set_name) for station in set_stations: try: - dp = data_preparation.build(station, name_affix=set_name, **kwargs) + dp = data_handler.build(station, name_affix=set_name, **kwargs) collection.add(dp) valid_stations.append(station) except (AttributeError, EmptyQueryResult): @@ -228,10 +228,10 @@ class PreProcessing(RunEnvironment): f"{len(set_stations)} valid stations.") return collection, valid_stations - def transformation(self, data_preparation, stations): - if hasattr(data_preparation, "transformation"): - kwargs = self.data_store.create_args_dict(data_preparation.requirements(), scope="train") - transformation_dict = data_preparation.transformation(stations, **kwargs) + def transformation(self, data_handler: AbstractDataHandler, stations): + if hasattr(data_handler, "transformation"): + kwargs = self.data_store.create_args_dict(data_handler.requirements(), scope="train") + transformation_dict = data_handler.transformation(stations, **kwargs) if transformation_dict is not None: self.data_store.set("transformation", transformation_dict) diff --git a/mlair/run_script.py b/mlair/run_script.py index 00a28f68..6dea98ba 100644 --- a/mlair/run_script.py +++ b/mlair/run_script.py @@ -27,7 +27,7 @@ def run(stations=None, model=None, batch_size=None, epochs=None, - data_preparation=None, + data_handler=None, **kwargs): params = inspect.getfullargspec(DefaultWorkflow).args diff --git a/test/test_run_modules/test_pre_processing.py b/test/test_run_modules/test_pre_processing.py index e62c8758..bdb8fdab 100644 --- a/test/test_run_modules/test_pre_processing.py +++ b/test/test_run_modules/test_pre_processing.py @@ -28,7 +28,7 @@ class TestPreProcessing: def obj_with_exp_setup(self): ExperimentSetup(stations=['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087', 'DEBW001'], statistics_per_var={'o3': 'dma8eu', 'temp': 'maximum'}, station_type="background", - data_preparation=DefaultDataHandler) + data_handler=DefaultDataHandler) pre = object.__new__(PreProcessing) super(PreProcessing, pre).__init__() yield pre @@ -90,7 +90,7 @@ class TestPreProcessing: pre = obj_with_exp_setup caplog.set_level(logging.INFO) stations = pre.data_store.get("stations", "general") - data_preparation = pre.data_store.get("data_preparation") + data_preparation = pre.data_store.get("data_handler") collection, valid_stations = pre.validate_station(data_preparation, stations, set_name=name) assert isinstance(collection, DataCollection) assert len(valid_stations) < len(stations) -- GitLab