From 4800b9e720264989c4456fa21a24c6a13f998e02 Mon Sep 17 00:00:00 2001
From: lukas leufen <l.leufen@fz-juelich.de>
Date: Wed, 26 Aug 2020 11:57:48 +0200
Subject: [PATCH] also updated naming for all variables in the run modules

---
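Reviewer notes (not part of the commit message): this is a pure rename of
the `data_preparation` keyword argument and data-store entry to
`data_handler`; behaviour is unchanged. A minimal before/after sketch of a
call site, assuming `DefaultDataHandler` is importable from
`mlair.data_handler` like the handler classes already imported there (the
station id is taken from the test setup below, everything else is
illustrative):

    from mlair.run_modules.experiment_setup import ExperimentSetup
    from mlair.data_handler import DefaultDataHandler

    # before this patch
    ExperimentSetup(stations=['DEBW107'], data_preparation=DefaultDataHandler)

    # after this patch: same default handler, new keyword name
    ExperimentSetup(stations=['DEBW107'], data_handler=DefaultDataHandler)
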
 mlair/run_modules/experiment_setup.py        |  4 +--
 mlair/run_modules/pre_processing.py          | 28 ++++++++++++++--------------
 mlair/run_script.py                          |  2 +-
 test/test_run_modules/test_pre_processing.py |  6 +++---
 4 files changed, 20 insertions(+), 20 deletions(-)
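
The new `AbstractDataHandler` type hint on `validate_station` also spells
out what PreProcessing expects from the object stored under "data_handler":
it is the handler class itself, on which `requirements()`, `build()` and,
if present, `transformation()` are called. A rough sketch of a custom
handler under those assumptions (class name and body are hypothetical, not
part of this patch):

    from mlair.data_handler import AbstractDataHandler

    class CustomDataHandler(AbstractDataHandler):  # hypothetical example
        def __init__(self, station, statistics_per_var=None, name_affix=None):
            self.station = station
            self.statistics_per_var = statistics_per_var

        @classmethod
        def requirements(cls):
            # names resolved from the data store and forwarded to build()
            return ["statistics_per_var"]

        @classmethod
        def build(cls, station, name_affix=None, **kwargs):
            # construct the handler for a single station; raising
            # EmptyQueryResult here marks the station as invalid
            return cls(station, name_affix=name_affix, **kwargs)

        @classmethod
        def transformation(cls, stations, **kwargs):
            # optional hook: return transformation statistics estimated on
            # the train stations, or None to leave the data store unchanged
            return None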

diff --git a/mlair/run_modules/experiment_setup.py b/mlair/run_modules/experiment_setup.py
index d7ecbac5..51e710c2 100644
--- a/mlair/run_modules/experiment_setup.py
+++ b/mlair/run_modules/experiment_setup.py
@@ -221,7 +221,7 @@ class ExperimentSetup(RunEnvironment):
                  train_min_length=None, val_min_length=None, test_min_length=None, extreme_values: list = None,
                  extremes_on_right_tail_only: bool = None, evaluate_bootstraps=None, plot_list=None, number_of_bootstraps=None,
                  create_new_bootstraps=None, data_path: str = None, batch_path: str = None, login_nodes=None,
-                 hpc_hosts=None, model=None, batch_size=None, epochs=None, data_preparation=None, **kwargs):
+                 hpc_hosts=None, model=None, batch_size=None, epochs=None, data_handler=None, **kwargs):
 
         # create run framework
         super().__init__()
@@ -290,7 +290,7 @@ class ExperimentSetup(RunEnvironment):
         self._set_param("sampling", sampling)
         self._set_param("transformation", transformation, default=DEFAULT_TRANSFORMATION)
         self._set_param("transformation", None, scope="preprocessing")
-        self._set_param("data_preparation", data_preparation, default=DefaultDataHandler)
+        self._set_param("data_handler", data_handler, default=DefaultDataHandler)
 
         # target
         self._set_param("target_var", target_var, default=DEFAULT_TARGET_VAR)
diff --git a/mlair/run_modules/pre_processing.py b/mlair/run_modules/pre_processing.py
index b4185df2..ed972896 100644
--- a/mlair/run_modules/pre_processing.py
+++ b/mlair/run_modules/pre_processing.py
@@ -10,7 +10,7 @@ from typing import Tuple
 import numpy as np
 import pandas as pd
 
-from mlair.data_handler import DataCollection
+from mlair.data_handler import DataCollection, AbstractDataHandler
 from mlair.helpers import TimeTracking
 from mlair.configuration import path_config
 from mlair.helpers.join import EmptyQueryResult
@@ -55,8 +55,8 @@ class PreProcessing(RunEnvironment):
 
     def _run(self):
         stations = self.data_store.get("stations")
-        data_preparation = self.data_store.get("data_preparation")
-        _, valid_stations = self.validate_station(data_preparation, stations, "preprocessing", overwrite_local_data=True)
+        data_handler = self.data_store.get("data_handler")
+        _, valid_stations = self.validate_station(data_handler, stations, "preprocessing", overwrite_local_data=True)
         if len(valid_stations) == 0:
             raise ValueError("Couldn't find any valid data according to given parameters. Abort experiment run.")
         self.data_store.set("stations", valid_stations)
@@ -187,12 +187,12 @@ class PreProcessing(RunEnvironment):
             set_stations = stations[index_list]
         logging.debug(f"{set_name.capitalize()} stations (len={len(set_stations)}): {set_stations}")
         # create set data_collection and store
-        data_preparation = self.data_store.get("data_preparation")
-        collection, valid_stations = self.validate_station(data_preparation, set_stations, set_name)
+        data_handler = self.data_store.get("data_handler")
+        collection, valid_stations = self.validate_station(data_handler, set_stations, set_name)
         self.data_store.set("stations", valid_stations, scope=set_name)
         self.data_store.set("data_collection", collection, scope=set_name)
 
-    def validate_station(self, data_preparation, set_stations, set_name=None, overwrite_local_data=False):
+    def validate_station(self, data_handler: AbstractDataHandler, set_stations, set_name=None, overwrite_local_data=False):
         """
-        Check if all given stations in `all_stations` are valid.
+        Check if all given stations in `set_stations` are valid.
 
@@ -212,14 +212,14 @@ class PreProcessing(RunEnvironment):
         logging.info(f"check valid stations started{' (%s)' % (set_name if set_name is not None else 'all')}")
         # calculate transformation using train data
         if set_name == "train":
-            self.transformation(data_preparation, set_stations)
+            self.transformation(data_handler, set_stations)
         # start station check
         collection = DataCollection()
         valid_stations = []
-        kwargs = self.data_store.create_args_dict(data_preparation.requirements(), scope=set_name)
+        kwargs = self.data_store.create_args_dict(data_handler.requirements(), scope=set_name)
         for station in set_stations:
             try:
-                dp = data_preparation.build(station, name_affix=set_name, **kwargs)
+                dp = data_handler.build(station, name_affix=set_name, **kwargs)
                 collection.add(dp)
                 valid_stations.append(station)
             except (AttributeError, EmptyQueryResult):
@@ -228,10 +228,10 @@ class PreProcessing(RunEnvironment):
                      f"{len(set_stations)} valid stations.")
         return collection, valid_stations
 
-    def transformation(self, data_preparation, stations):
-        if hasattr(data_preparation, "transformation"):
-            kwargs = self.data_store.create_args_dict(data_preparation.requirements(), scope="train")
-            transformation_dict = data_preparation.transformation(stations, **kwargs)
+    def transformation(self, data_handler: AbstractDataHandler, stations):
+        if hasattr(data_handler, "transformation"):
+            kwargs = self.data_store.create_args_dict(data_handler.requirements(), scope="train")
+            transformation_dict = data_handler.transformation(stations, **kwargs)
             if transformation_dict is not None:
                 self.data_store.set("transformation", transformation_dict)
 
diff --git a/mlair/run_script.py b/mlair/run_script.py
index 00a28f68..6dea98ba 100644
--- a/mlair/run_script.py
+++ b/mlair/run_script.py
@@ -27,7 +27,7 @@ def run(stations=None,
         model=None,
         batch_size=None,
         epochs=None,
-        data_preparation=None,
+        data_handler=None,
         **kwargs):
 
     params = inspect.getfullargspec(DefaultWorkflow).args
diff --git a/test/test_run_modules/test_pre_processing.py b/test/test_run_modules/test_pre_processing.py
index e62c8758..bdb8fdab 100644
--- a/test/test_run_modules/test_pre_processing.py
+++ b/test/test_run_modules/test_pre_processing.py
@@ -28,7 +28,7 @@ class TestPreProcessing:
     def obj_with_exp_setup(self):
         ExperimentSetup(stations=['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087', 'DEBW001'],
                         statistics_per_var={'o3': 'dma8eu', 'temp': 'maximum'}, station_type="background",
-                        data_preparation=DefaultDataHandler)
+                        data_handler=DefaultDataHandler)
         pre = object.__new__(PreProcessing)
         super(PreProcessing, pre).__init__()
         yield pre
@@ -90,7 +90,7 @@ class TestPreProcessing:
         pre = obj_with_exp_setup
         caplog.set_level(logging.INFO)
         stations = pre.data_store.get("stations", "general")
-        data_preparation = pre.data_store.get("data_preparation")
-        collection, valid_stations = pre.validate_station(data_preparation, stations, set_name=name)
+        data_handler = pre.data_store.get("data_handler")
+        collection, valid_stations = pre.validate_station(data_handler, stations, set_name=name)
         assert isinstance(collection, DataCollection)
         assert len(valid_stations) < len(stations)
-- 
GitLab