From a40504e943729ff95a0a8231c7146e4e700ca51f Mon Sep 17 00:00:00 2001
From: lukas leufen <l.leufen@fz-juelich.de>
Date: Fri, 15 Nov 2019 09:29:01 +0100
Subject: [PATCH] set variables, stations, ... in experiment setup

---
 run.py | 49 ++++++++++++++++++++++++++++++++++++++++---------
 1 file changed, 40 insertions(+), 9 deletions(-)

diff --git a/run.py b/run.py
index c9523595..6b5c367d 100644
--- a/run.py
+++ b/run.py
@@ -43,20 +43,30 @@ class ExperimentSetup:
     trainable: Train new model if true, otherwise try to load existing model
     """
 
-    def __init__(self, trainable=False):
+    def __init__(self, **kwargs):
         self.data_path = None
         self.experiment_path = None
         self.experiment_name = None
         self.trainable = None
         self.fraction_of_train = None
         self.use_all_stations_on_all_data_sets = None
-        self.setup_experiment(trainable)
-
-    def _set_param(self, param, value):
+        self.network = None
+        self.var_all_dict = None
+        self.all_stations = None
+        self.variables = None
+        self.dimensions = None
+        self.dim = None
+        self.target_dim = None
+        self.target_var = None
+        self.setup_experiment(**kwargs)
+
+    def _set_param(self, param, value, default=None):
+        if default is not None:
+            value = value.get(param, default)
         setattr(self, param, value)
-        logging.debug(f"set attribute: {param}={value}")
+        logging.info(f"set experiment attribute: {param}={value}")
 
-    def setup_experiment(self, trainable):
+    def setup_experiment(self, **kwargs):
 
         # set data path of this experiment
         self._set_param("data_path", helpers.prepare_host())
@@ -69,13 +79,34 @@ class ExperimentSetup:
         helpers.check_path_and_create(self.experiment_path)
 
         # set if model is trainable
-        self._set_param("trainable", trainable)
+        self._set_param("trainable", kwargs, default=True)
 
         # set fraction of train
-        self._set_param("fraction_of_train", 0.8)
+        self._set_param("fraction_of_train", kwargs, default=0.8)
 
         # use all stations on all data sets (train, val, test)
-        self._set_param("use_all_stations_on_all_data_sets", True)
+        self._set_param("use_all_stations_on_all_data_sets", kwargs, default=True)
+        self._set_param("network", kwargs, default="AIRBASE")
+        self._set_param("var_all_dict", kwargs, default={'o3': 'dma8eu', 'relhum': 'average_values', 'temp': 'maximum',
+                                                         'u': 'average_values', 'v': 'average_values', 'no': 'dma8eu',
+                                                         'no2': 'dma8eu', 'cloudcover': 'average_values',
+                                                         'pblheight': 'maximum'})
+        self._set_param("all_stations", kwargs, default=['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087',
+                                                         'DEBY052', 'DEBY032', 'DEBW022', 'DEBY004', 'DEBY020',
+                                                         'DEBW030', 'DEBW037', 'DEBW031', 'DEBW015', 'DEBW073',
+                                                         'DEBY039', 'DEBW038', 'DEBW081', 'DEBY075', 'DEBW040',
+                                                         'DEBY053', 'DEBW059', 'DEBW027', 'DEBY072', 'DEBW042',
+                                                         'DEBW039', 'DEBY001', 'DEBY113', 'DEBY089', 'DEBW024',
+                                                         'DEBW004', 'DEBY037', 'DEBW056', 'DEBW029', 'DEBY068',
+                                                         'DEBW010', 'DEBW026', 'DEBY002', 'DEBY079', 'DEBW084',
+                                                         'DEBY049', 'DEBY031', 'DEBW019', 'DEBW001', 'DEBY063',
+                                                         'DEBY005', 'DEBW046', 'DEBW103', 'DEBW052', 'DEBW034',
+                                                         'DEBY088', ])
+        self._set_param("variables", kwargs, default=list(self.var_all_dict.keys()))
+        self._set_param("dimensions", kwargs, default={'new_index': ['datetime', 'Stations']})
+        self._set_param("dim", kwargs, default='datetime')
+        self._set_param("target_dim", kwargs, default='variables')
+        self._set_param("target_var", kwargs, default="o3")
 
 
 class PreProcessing(run):
-- 
GitLab