From e1144c96a5126c121160d373263efd953c7b3c21 Mon Sep 17 00:00:00 2001
From: lukas leufen <l.leufen@fz-juelich.de>
Date: Fri, 15 Nov 2019 15:57:08 +0100
Subject: [PATCH] small code refactoring to use new modular structure

---
 run.py                      | 24 +++++++++--------
 src/data_generator.py       | 10 +++-----
 src/experiment_setup.py     | 51 ++++++++++++++++++++++---------------
 src/helpers.py              | 19 ++++++++------
 src/join.py                 |  1 -
 src/modules.py              | 40 +++++++++++++++++++++++++++++
 test/test_data_generator.py |  2 +-
 7 files changed, 100 insertions(+), 47 deletions(-)

diff --git a/run.py b/run.py
index fb1ef8d9..6b115cf2 100644
--- a/run.py
+++ b/run.py
@@ -8,16 +8,7 @@ from src.experiment_setup import ExperimentSetup
 from src.modules import run, PreProcessing, Training, PostProcessing
 
 
-formatter = "%(asctime)s - %(levelname)s: %(message)s  [%(filename)s:%(funcName)s:%(lineno)s]"
-logging.basicConfig(level=logging.INFO, format=formatter)
-
-
-if __name__ == "__main__":
-
-    parser = argparse.ArgumentParser()
-    parser.add_argument('--experiment_date', metavar='--exp_date', type=str, nargs=1, default=None,
-                        help="set experiment date as string")
-    args = parser.parse_args()
+def main():
 
     with run():
         exp_setup = ExperimentSetup(args, trainable=True)
@@ -27,3 +18,16 @@ if __name__ == "__main__":
         Training(exp_setup)
 
         PostProcessing(exp_setup)
+
+
+if __name__ == "__main__":
+    # configure root logging first (DEBUG level; each record carries file, function and line number)
+    formatter = '%(asctime)s - %(levelname)s: %(message)s  [%(filename)s:%(funcName)s:%(lineno)s]'
+    logging.basicConfig(format=formatter, level=logging.DEBUG)
+    # the only command line option is an optional experiment date
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--experiment_date', metavar='--exp_date', type=str, nargs=1, default=None,
+                        help="set experiment date as string")
+    args = parser.parse_args()
+    # NOTE(review): main() takes no parameters and reads this module-level `args`; confirm this coupling is intended
+    main()
diff --git a/src/data_generator.py b/src/data_generator.py
index d067e1e9..3d8a1c7c 100644
--- a/src/data_generator.py
+++ b/src/data_generator.py
@@ -6,8 +6,6 @@ from src import helpers
 from src.data_preparation import DataPrep
 import os
 from typing import Union, List, Tuple
-import decimal
-import numpy as np
 import xarray as xr
 
 
@@ -20,11 +18,11 @@ class DataGenerator(keras.utils.Sequence):
     one entry of integer or string
     """
 
-    def __init__(self, path: str, network: str, stations: Union[str, List[str]], variables: List[str],
+    def __init__(self, data_path: str, network: str, stations: Union[str, List[str]], variables: List[str],
                  interpolate_dim: str, target_dim: str, target_var: str, interpolate_method: str = "linear",
                  limit_nan_fill: int = 1, window_history: int = 7, window_lead_time: int = 4,
                  transform_method: str = "standardise", **kwargs):
-        self.path = os.path.abspath(path)
+        self.data_path = os.path.abspath(data_path)
         self.network = network
         self.stations = helpers.to_list(stations)
         self.variables = variables
@@ -42,7 +40,7 @@ class DataGenerator(keras.utils.Sequence):
         """
         display all class attributes
         """
-        return f"DataGenerator(path='{self.path}', network='{self.network}', stations={self.stations}, " \
+        return f"DataGenerator(path='{self.data_path}', network='{self.network}', stations={self.stations}, " \
                f"variables={self.variables}, interpolate_dim='{self.interpolate_dim}', target_dim='{self.target_dim}'" \
                f", target_var='{self.target_var}', **{self.kwargs})"
 
@@ -96,7 +94,7 @@ class DataGenerator(keras.utils.Sequence):
         :return: preprocessed data as a DataPrep instance
         """
         station = self.get_station_key(key)
-        data = DataPrep(self.path, self.network, station, self.variables, **self.kwargs)
+        data = DataPrep(self.data_path, self.network, station, self.variables, **self.kwargs)
         data.interpolate(self.interpolate_dim, method=self.interpolate_method, limit=self.limit_nan_fill)
         data.transform("datetime", method=self.transform_method)
         data.make_history_window(self.interpolate_dim, self.window_history)
diff --git a/src/experiment_setup.py b/src/experiment_setup.py
index 18b5f714..d8cf04ec 100644
--- a/src/experiment_setup.py
+++ b/src/experiment_setup.py
@@ -4,16 +4,17 @@ __date__ = '2019-11-15'
 
 from src import helpers
 import logging
+import argparse
 
 
-class ExperimentSetup:
+class ExperimentSetup(object):
     """
     params:
     trainable: Train new model if true, otherwise try to load existing model
     """
 
     def __init__(self, parser_args, **kwargs):
-        self.args = parser_args
+        self.args = self._set_parser_args(parser_args)
         self.data_path = None
         self.experiment_path = None
         self.experiment_name = None
@@ -22,10 +23,10 @@ class ExperimentSetup:
         self.use_all_stations_on_all_data_sets = None
         self.network = None
         self.var_all_dict = None
-        self.all_stations = None
+        self.stations = None
         self.variables = None
         self.dimensions = None
-        self.dim = None
+        self.interpolate_dim = None
         self.target_dim = None
         self.target_var = None
         self.setup_experiment(**kwargs)
@@ -36,13 +37,24 @@ class ExperimentSetup:
         setattr(self, param, value)
         logging.debug(f"set experiment attribute: {param}={value}")
 
+    @staticmethod
+    def _set_parser_args(args):
+        """
+        Transform args to dict if given as argparse.Namespace
+        :param args: parsed command line arguments as argparse.Namespace, or any other object (e.g. a dict)
+        :return: the attribute dictionary of the namespace, or args unchanged if it is not a Namespace
+        """
+        if isinstance(args, argparse.Namespace):
+            return args.__dict__
+        return args
+
     def setup_experiment(self, **kwargs):
 
         # set data path of this experiment
         self._set_param("data_path", helpers.prepare_host())
 
         # set experiment name
-        exp_date = self.args.experiment_date
+        exp_date = self.args.get("experiment_date")
         exp_name, exp_path = helpers.set_experiment_name(experiment_date=exp_date)
         self._set_param("experiment_name", exp_name)
         self._set_param("experiment_path", exp_path)
@@ -57,23 +69,20 @@ class ExperimentSetup:
         # use all stations on all data sets (train, val, test)
         self._set_param("use_all_stations_on_all_data_sets", kwargs, default=True)
         self._set_param("network", kwargs, default="AIRBASE")
-        self._set_param("var_all_dict", kwargs, default={'o3': 'dma8eu', 'relhum': 'average_values', 'temp': 'maximum',
-                                                         'u': 'average_values', 'v': 'average_values', 'no': 'dma8eu',
-                                                         'no2': 'dma8eu', 'cloudcover': 'average_values',
-                                                         'pblheight': 'maximum'})
-        self._set_param("all_stations", kwargs, default=['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087',
-                                                         'DEBY052', 'DEBY032', 'DEBW022', 'DEBY004', 'DEBY020',
-                                                         'DEBW030', 'DEBW037', 'DEBW031', 'DEBW015', 'DEBW073',
-                                                         'DEBY039', 'DEBW038', 'DEBW081', 'DEBY075', 'DEBW040',
-                                                         'DEBY053', 'DEBW059', 'DEBW027', 'DEBY072', 'DEBW042',
-                                                         'DEBW039', 'DEBY001', 'DEBY113', 'DEBY089', 'DEBW024',
-                                                         'DEBW004', 'DEBY037', 'DEBW056', 'DEBW029', 'DEBY068',
-                                                         'DEBW010', 'DEBW026', 'DEBY002', 'DEBY079', 'DEBW084',
-                                                         'DEBY049', 'DEBY031', 'DEBW019', 'DEBW001', 'DEBY063',
-                                                         'DEBY005', 'DEBW046', 'DEBW103', 'DEBW052', 'DEBW034',
-                                                         'DEBY088', ])
+        self._set_param("var_all_dict", kwargs,
+                        default={'o3': 'dma8eu', 'relhum': 'average_values', 'temp': 'maximum', 'u': 'average_values',
+                                 'v': 'average_values', 'no': 'dma8eu', 'no2': 'dma8eu', 'cloudcover': 'average_values',
+                                 'pblheight': 'maximum'})
+        self._set_param("stations", kwargs,
+                        default=['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087', 'DEBY052', 'DEBY032', 'DEBW022',
+                                 'DEBY004', 'DEBY020', 'DEBW030', 'DEBW037', 'DEBW031', 'DEBW015', 'DEBW073', 'DEBY039',
+                                 'DEBW038', 'DEBW081', 'DEBY075', 'DEBW040', 'DEBY053', 'DEBW059', 'DEBW027', 'DEBY072',
+                                 'DEBW042', 'DEBW039', 'DEBY001', 'DEBY113', 'DEBY089', 'DEBW024', 'DEBW004', 'DEBY037',
+                                 'DEBW056', 'DEBW029', 'DEBY068', 'DEBW010', 'DEBW026', 'DEBY002', 'DEBY079', 'DEBW084',
+                                 'DEBY049', 'DEBY031', 'DEBW019', 'DEBW001', 'DEBY063', 'DEBY005', 'DEBW046', 'DEBW103',
+                                 'DEBW052', 'DEBW034', 'DEBY088', ])
         self._set_param("variables", kwargs, default=list(self.var_all_dict.keys()))
         self._set_param("dimensions", kwargs, default={'new_index': ['datetime', 'Stations']})
-        self._set_param("dim", kwargs, default='datetime')
+        self._set_param("interpolate_dim", kwargs, default='datetime')
         self._set_param("target_dim", kwargs, default='variables')
         self._set_param("target_var", kwargs, default="o3")
diff --git a/src/helpers.py b/src/helpers.py
index 4fd5cc2b..781a6e6c 100644
--- a/src/helpers.py
+++ b/src/helpers.py
@@ -11,7 +11,6 @@ import numpy as np
 import os
 import time
 import socket
-import sys
 
 
 def to_list(arg):
@@ -23,9 +22,9 @@ def to_list(arg):
 def check_path_and_create(path):
     try:
         os.makedirs(path)
-        logging.info(f"Created path: {path}")
+        logging.debug(f"Created path: {path}")
     except FileExistsError:
-        logging.info(f"Path already exists: {path}")
+        logging.debug(f"Path already exists: {path}")
 
 
 def l_p_loss(power: int):
@@ -134,7 +133,7 @@ class TimeTracking(object):
         return self._duration()
 
 
-def prepare_host():
+def prepare_host(create_new=True):
     hostname = socket.gethostname()
     user = os.getlogin()
     if hostname == 'ZAM144':
@@ -142,7 +141,7 @@ def prepare_host():
     elif hostname == 'zam347':
         path = f'/home/{user}/Data/toar_daily/'
     elif hostname == 'linux-gzsx':
-        path = f'/home/{user}/machinelearningtools'
+        path = f'/home/{user}/machinelearningtools/data/toar_daily/'
     elif (len(hostname) > 2) and (hostname[:2] == 'jr'):
         path = f'/p/project/cjjsc42/{user}/DATA/toar_daily/'
     elif (len(hostname) > 2) and (hostname[:2] == 'jw'):
@@ -151,10 +150,14 @@ def prepare_host():
         logging.error(f"unknown host '{hostname}'")
         raise OSError(f"unknown host '{hostname}'")
     if not os.path.exists(path):
-        logging.error(f"path '{path}' does not exist for host '{hostname}'.")
-        raise NotADirectoryError(f"path '{path}' does not exist for host '{hostname}'.")
+        if create_new:
+            check_path_and_create(path)
+            return path
+        else:
+            logging.error(f"path '{path}' does not exist for host '{hostname}'.")
+            raise NotADirectoryError(f"path '{path}' does not exist for host '{hostname}'.")
     else:
-        logging.info(f"set path to: {path}")
+        logging.debug(f"set path to: {path}")
         return path
 
 
diff --git a/src/join.py b/src/join.py
index a8b8edc7..2b13dcf4 100644
--- a/src/join.py
+++ b/src/join.py
@@ -11,7 +11,6 @@ from typing import Iterator, Union, List
 from src import helpers
 
 join_url_base = 'https://join.fz-juelich.de/services/rest/surfacedata/'
-logging.basicConfig(level=logging.INFO)
 
 
 def download_join(station_name: Union[str, List[str]], statvar: dict) -> [pd.DataFrame, pd.DataFrame]:
diff --git a/src/modules.py b/src/modules.py
index 04b7f849..0e03a352 100644
--- a/src/modules.py
+++ b/src/modules.py
@@ -1,6 +1,9 @@
 from src.helpers import TimeTracking
 import logging
 import time
+from src.data_generator import DataGenerator
+from src.experiment_setup import ExperimentSetup
+import argparse
 
 
 class run(object):
@@ -32,6 +35,29 @@ class PreProcessing(run):
     def __init__(self, setup):
         super().__init__()
         self.setup = setup
+        self.kwargs = None  # options forwarded to DataGenerator; filled in by _run()
+        self._run()  # the stage executes immediately on construction
+
+    def _run(self):
+        # hard-coded data-preparation options; the per-variable statistics come from the experiment setup
+        options = {'start': '1997-01-01', 'end': '2017-12-31', 'limit': 1, 'window_history': 13,
+                   'window_lead_time': 3, 'method': 'linear', 'statistics_per_var': self.setup.var_all_dict}
+        self.kwargs = options; self.check_valid_stations()
+
+    def check_valid_stations(self):
+        """
+        Create a DataGenerator for every station of the setup and request its data (work in progress: station
+        validity is not evaluated yet and nothing is returned).
+        """
+        logging.debug("check valid stations started")
+        for station in self.setup.stations:
+            # build the arguments from a copy: assigning into setup.__dict__ itself would replace the station
+            # list stored on the shared setup object by a single station name
+            args = dict(self.setup.__dict__, stations=station)
+            generator = DataGenerator(**args, **self.kwargs)
+            generator.get_data_generator(station)
+            logging.debug(f"requested data for station {station}")
+
 
 
 class Training(run):
@@ -46,3 +72,17 @@ class PostProcessing(run):
     def __init__(self, setup):
         super().__init__()
         self.setup = setup
+
+
+if __name__ == "__main__":
+    # executing this module directly runs only the pre-processing stage
+    formatter = '%(asctime)s - %(levelname)s: %(message)s  [%(filename)s:%(funcName)s:%(lineno)s]'
+    logging.basicConfig(format=formatter, level=logging.DEBUG)
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--experiment_date', metavar='--exp_date', type=str, nargs=1, default=None,
+                        help="set experiment date as string")
+    args = parser.parse_args()
+    with run():
+        setup = ExperimentSetup(args, test=True)
+        PreProcessing(setup)
diff --git a/test/test_data_generator.py b/test/test_data_generator.py
index ab6233f3..7c745782 100644
--- a/test/test_data_generator.py
+++ b/test/test_data_generator.py
@@ -17,7 +17,7 @@ class TestDataGenerator:
                              'datetime', 'variables', 'o3')
 
     def test_init(self, gen):
-        assert gen.path == os.path.join(os.path.dirname(__file__), 'data')
+        assert gen.data_path == os.path.join(os.path.dirname(__file__), 'data')
         assert gen.network == 'UBA'
         assert gen.stations == ['DEBW107']
         assert gen.variables == ['o3', 'temp']
-- 
GitLab