diff --git a/mlair/model_modules/convolutional_networks.py b/mlair/model_modules/convolutional_networks.py
index 2270c1ee2abf8b17913e6017181cffcde17bd923..7bdd2ce210c126bf47dcf02c28f4efaacf789457 100644
--- a/mlair/model_modules/convolutional_networks.py
+++ b/mlair/model_modules/convolutional_networks.py
@@ -75,7 +75,7 @@ class CNNfromConfig(AbstractModelClass):
         # apply to model
         self.set_model()
         self.set_compile_options()
-        self.set_custom_objects(loss=custom_loss([keras.losses.mean_squared_error, var_loss]), var_loss=var_loss)
+        self.set_custom_objects(loss=self.compile_options["loss"][0], var_loss=var_loss)
 
     def set_model(self):
         x_input = keras.layers.Input(shape=self._input_shape)
diff --git a/mlair/run_modules/model_setup.py b/mlair/run_modules/model_setup.py
index bf09ac6fc8c63bcfc31024dafb550c84e0ff5df4..b51a3f9c76ace4feef72e4c96945b463fb69a673 100644
--- a/mlair/run_modules/model_setup.py
+++ b/mlair/run_modules/model_setup.py
@@ -224,10 +224,13 @@ class ModelSetup(RunEnvironment):
             if v is None:
                 continue
             if isinstance(v, list):
-                if isinstance(v[0], dict):
-                    v = ["{" + vi + "}" for vi in [",".join(f"{_f(str(uk))}:{_f(str(uv))}" for uk, uv in d.items()) for d in v]]
+                if len(v) > 0:
+                    if isinstance(v[0], dict):
+                        v = ["{" + vi + "}" for vi in [",".join(f"{_f(str(uk))}:{_f(str(uv))}" for uk, uv in d.items()) for d in v]]
+                    else:
+                        v = ",".join(_f(str(u)) for u in v)
                 else:
-                    v = ",".join(_f(str(u)) for u in v)
+                    v = "[]"
             if "<" in str(v):
                 v = _f(str(v))
             df.loc[k] = str(v)
diff --git a/mlair/run_modules/run_environment.py b/mlair/run_modules/run_environment.py
index 191ee30f0485c6abc42a8d718612b7b26feb221a..2bc81750bf86d60e0c59bbf0fef68ae9c29138c9 100644
--- a/mlair/run_modules/run_environment.py
+++ b/mlair/run_modules/run_environment.py
@@ -113,21 +113,22 @@ class RunEnvironment(object):
         not as inheritance from this class, log file is copied and data store is cleared.
         """
         if not self.del_by_exit:
-            self.time.stop()
-            try:
-                logging.info(f"{self._name} finished after {self.time}")
-            except NameError:
-                pass
-            self.del_by_exit = True
-            # copy log file and clear data store only if called as base class and not as super class
-            if self.__class__.__name__ == "RunEnvironment":
+            if hasattr(self, "time"):
+                self.time.stop()
                 try:
-                    self.__plot_tracking()
-                    self.__save_tracking()
-                    self.__move_log_file()
-                except (FileNotFoundError, NameError):
+                    logging.info(f"{self._name} finished after {self.time}")
+                except NameError:
                     pass
-                self.data_store.clear_data_store()
+            self.del_by_exit = True
+            # copy log file and clear data store only if called as base class and not as super class
+            if self.__class__.__name__ == "RunEnvironment":
+                try:
+                    self.__plot_tracking()
+                    self.__save_tracking()
+                    self.__move_log_file()
+                except (FileNotFoundError, NameError):
+                    pass
+                self.data_store.clear_data_store()
 
     def __enter__(self):
         """Enter run environment."""
diff --git a/test/test_run_modules/test_pre_processing.py b/test/test_run_modules/test_pre_processing.py
index 743900bb4b0eab96a35e0f263d7525fcf060b597..6646e1a4795756edd1792ef91f535132e8cde61d 100644
--- a/test/test_run_modules/test_pre_processing.py
+++ b/test/test_run_modules/test_pre_processing.py
@@ -9,8 +9,6 @@ from mlair.helpers import PyTestRegex
 from mlair.run_modules.experiment_setup import ExperimentSetup
 from mlair.run_modules.pre_processing import PreProcessing
 from mlair.run_modules.run_environment import RunEnvironment
-import pandas as pd
-import numpy as np
 
 import multiprocessing
 
@@ -30,24 +28,25 @@ class TestPreProcessing:
 
     @pytest.fixture
    def obj_with_exp_setup(self):
-        ExperimentSetup(stations=['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087', 'DEBW99X'],
+        ExperimentSetup(stations=['DEBW107', 'DEBW013', 'DEBW087', 'DEBW99X'],
                         statistics_per_var={'o3': 'dma8eu', 'temp': 'maximum'}, station_type="background",
-                        data_handler=DefaultDataHandler)
+                        data_origin={'o3': 'UBA', 'temp': 'UBA'}, data_handler=DefaultDataHandler)
         pre = object.__new__(PreProcessing)
         super(PreProcessing, pre).__init__()
         yield pre
         RunEnvironment().__del__()
 
     def test_init(self, caplog):
-        ExperimentSetup(stations=['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087'],
-                        statistics_per_var={'o3': 'dma8eu', 'temp': 'maximum'})
+        ExperimentSetup(stations=['DEBW107', 'DEBW013', 'DEBW087'],
+                        statistics_per_var={'o3': 'dma8eu', 'temp': 'maximum'},
+                        data_origin={'o3': 'UBA', 'temp': 'UBA'})
         caplog.clear()
         caplog.set_level(logging.INFO)
         with PreProcessing():
             assert caplog.record_tuples[0] == ('root', 20, 'PreProcessing started')
             assert caplog.record_tuples[1] == ('root', 20, 'check valid stations started (preprocessing)')
-            assert caplog.record_tuples[-6] == ('root', 20, PyTestRegex(r'run for \d+:\d+:\d+ \(hh:mm:ss\) to check 5 '
-                                                                        r'station\(s\). Found 5/5 valid stations.'))
+            assert caplog.record_tuples[-6] == ('root', 20, PyTestRegex(r'run for \d+:\d+:\d+ \(hh:mm:ss\) to check 3 '
+                                                                        r'station\(s\). Found 3/3 valid stations.'))
             assert caplog.record_tuples[-5] == ('root', 20, "use serial create_info_df (train)")
             assert caplog.record_tuples[-4] == ('root', 20, "use serial create_info_df (val)")
             assert caplog.record_tuples[-3] == ('root', 20, "use serial create_info_df (test)")
@@ -77,23 +76,23 @@ class TestPreProcessing:
         caplog.set_level(logging.DEBUG)
         obj_with_exp_setup.data_store.set("use_all_stations_on_all_data_sets", False, "general")
         obj_with_exp_setup.create_set_split(slice(0, 2), "awesome")
-        assert ('root', 10, "Awesome stations (len=2): ['DEBW107', 'DEBY081']") in caplog.record_tuples
+        assert ('root', 10, "Awesome stations (len=2): ['DEBW107', 'DEBW013']") in caplog.record_tuples
         data_store = obj_with_exp_setup.data_store
         assert isinstance(data_store.get("data_collection", "general.awesome"), DataCollection)
         with pytest.raises(NameNotFoundInScope):
             data_store.get("data_collection", "general")
-        assert data_store.get("stations", "general.awesome") == ["DEBW107", "DEBY081"]
+        assert data_store.get("stations", "general.awesome") == ["DEBW107", "DEBW013"]
 
     def test_create_set_split_all_stations(self, caplog, obj_with_exp_setup):
         caplog.set_level(logging.DEBUG)
         obj_with_exp_setup.create_set_split(slice(0, 2), "awesome")
-        message = "Awesome stations (len=6): ['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087', 'DEBW99X']"
+        message = "Awesome stations (len=4): ['DEBW107', 'DEBW013', 'DEBW087', 'DEBW99X']"
         assert ('root', 10, message) in caplog.record_tuples
         data_store = obj_with_exp_setup.data_store
         assert isinstance(data_store.get("data_collection", "general.awesome"), DataCollection)
         with pytest.raises(NameNotFoundInScope):
             data_store.get("data_collection", "general")
-        assert data_store.get("stations", "general.awesome") == ['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087']
+        assert data_store.get("stations", "general.awesome") == ['DEBW107', 'DEBW013', 'DEBW087']
 
     @pytest.mark.parametrize("name", (None, "tester"))
     def test_validate_station_serial(self, caplog, obj_with_exp_setup, name):
@@ -108,8 +107,8 @@ class TestPreProcessing:
         expected = "check valid stations started" + ' (%s)' % (name if name else 'all')
         assert caplog.record_tuples[0] == ('root', 20, expected)
         assert caplog.record_tuples[1] == ('root', 20, "use serial validate station approach")
-        assert caplog.record_tuples[-1] == ('root', 20, PyTestRegex(r'run for \d+:\d+:\d+ \(hh:mm:ss\) to check 6 '
-                                                                    r'station\(s\). Found 5/6 valid stations.'))
+        assert caplog.record_tuples[-1] == ('root', 20, PyTestRegex(r'run for \d+:\d+:\d+ \(hh:mm:ss\) to check 4 '
+                                                                    r'station\(s\). Found 3/4 valid stations.'))
 
     @mock.patch("psutil.cpu_count", return_value=3)
     @mock.patch("multiprocessing.Pool", return_value=multiprocessing.Pool(3))
@@ -126,8 +125,8 @@
         assert caplog.record_tuples[0] == ('root', 20, "check valid stations started (all)")
         assert caplog.record_tuples[1] == ('root', 20, "use parallel validate station approach")
         assert caplog.record_tuples[2] == ('root', 20, "running 3 processes in parallel")
-        assert caplog.record_tuples[-1] == ('root', 20, PyTestRegex(r'run for \d+:\d+:\d+ \(hh:mm:ss\) to check 6 '
-                                                                    r'station\(s\). Found 5/6 valid stations.'))
+        assert caplog.record_tuples[-1] == ('root', 20, PyTestRegex(r'run for \d+:\d+:\d+ \(hh:mm:ss\) to check 4 '
+                                                                    r'station\(s\). Found 3/4 valid stations.'))
 
     def test_split_set_indices(self, obj_super_init):
         dummy_list = list(range(0, 15))
@@ -146,24 +145,3 @@ class TestPreProcessing:
         class data_preparation_no_trans: pass
 
         assert pre.transformation(data_preparation_no_trans, stations) is None
-
-    # @pytest.fixture
-    # def dummy_df(self):
-    #     data_dict = {'station_name': {'DEBW013': 'Stuttgart Bad Cannstatt', 'DEBW076': 'Baden-Baden',
-    #                                   'DEBW087': 'Schwäbische_Alb', 'DEBW107': 'Tübingen',
-    #                                   'DEBY081': 'Garmisch-Partenkirchen/Kreuzeckbahnstraße', '# Stations': np.nan,
-    #                                   '# Samples': np.nan},
-    #                  'station_lon': {'DEBW013': 9.2297, 'DEBW076': 8.2202, 'DEBW087': 9.2076, 'DEBW107': 9.0512,
-    #                                  'DEBY081': 11.0631, '# Stations': np.nan, '# Samples': np.nan},
-    #                  'station_lat': {'DEBW013': 48.8088, 'DEBW076': 48.7731, 'DEBW087': 48.3458, 'DEBW107': 48.5077,
-    #                                  'DEBY081': 47.4764, '# Stations': np.nan, '# Samples': np.nan},
-    #                  'station_alt': {'DEBW013': 235.0, 'DEBW076': 148.0, 'DEBW087': 798.0, 'DEBW107': 325.0,
-    #                                  'DEBY081': 735.0, '# Stations': np.nan, '# Samples': np.nan},
-    #                  'train': {'DEBW013': 1413, 'DEBW076': 3002, 'DEBW087': 3016, 'DEBW107': 1782, 'DEBY081': 2837,
-    #                            '# Stations': 6, '# Samples': 12050},
-    #                  'val': {'DEBW013': 698, 'DEBW076': 715, 'DEBW087': 700, 'DEBW107': 701, 'DEBY081': 456,
-    #                          '# Stations': 6, '# Samples': 3270},
-    #                  'test': {'DEBW013': 1066, 'DEBW076': 696, 'DEBW087': 1080, 'DEBW107': 1080, 'DEBY081': 700,
-    #                          '# Stations': 6, '# Samples': 4622}}
-    #     df = pd.DataFrame.from_dict(data_dict)
-    #     return df
diff --git a/test/test_run_modules/test_training.py b/test/test_run_modules/test_training.py
index 8f1fcd1943f9f203e738053017e00f8c269afef1..cdaa7f506d6b4b655dc582a331eea5a71b776c32 100644
--- a/test/test_run_modules/test_training.py
+++ b/test/test_run_modules/test_training.py
@@ -23,29 +23,6 @@ from mlair.run_modules.run_environment import RunEnvironment
 from mlair.run_modules.training import Training
 
 
-def my_test_model(activation, window_history_size, channels, output_size, dropout_rate, add_minor_branch=False):
-    inception_model = InceptionModelBase()
-    conv_settings_dict1 = {
-        'tower_1': {'reduction_filter': 8, 'tower_filter': 8 * 2, 'tower_kernel': (3, 1), 'activation': activation},
-        'tower_2': {'reduction_filter': 8, 'tower_filter': 8 * 2, 'tower_kernel': (5, 1), 'activation': activation}, }
-    pool_settings_dict1 = {'pool_kernel': (3, 1), 'tower_filter': 8 * 2, 'activation': activation}
-    X_input = keras.layers.Input(shape=(window_history_size + 1, 1, channels))
-    X_in = inception_model.inception_block(X_input, conv_settings_dict1, pool_settings_dict1)
-    if add_minor_branch:
-        out = [flatten_tail(X_in, inner_neurons=64, activation=activation, output_neurons=4,
-                            output_activation='linear', reduction_filter=64,
-                            name='Minor_1', dropout_rate=dropout_rate,
-                            )]
-    else:
-        out = []
-    X_in = keras.layers.Dropout(dropout_rate)(X_in)
-    out.append(flatten_tail(X_in, inner_neurons=64, activation=activation, output_neurons=output_size,
-                            output_activation='linear', reduction_filter=64,
-                            name='Main', dropout_rate=dropout_rate,
-                            ))
-    return keras.Model(inputs=X_input, outputs=out)
-
-
 class TestTraining:
 
     @pytest.fixture
@@ -90,15 +67,6 @@ class TestTraining:
             RunEnvironment().__del__()
         except AssertionError:
             pass
-        # try:
-        #     yield obj
-        # finally:
-        #     if os.path.exists(path):
-        #         shutil.rmtree(path)
-        #     try:
-        #         RunEnvironment().__del__()
-        #     except AssertionError:
-        #         pass
 
     @pytest.fixture
     def learning_rate(self):
@@ -150,12 +118,16 @@ class TestTraining:
         return {'o3': 'dma8eu', 'temp': 'maximum'}
 
     @pytest.fixture
-    def data_collection(self, path, window_history_size, window_lead_time, statistics_per_var):
-        data_prep = DefaultDataHandler.build(['DEBW107'], data_path=os.path.join(path, 'data'),
+    def data_origin(self):
+        return {'o3': 'UBA', 'temp': 'UBA'}
+
+    @pytest.fixture
+    def data_collection(self, path, window_history_size, window_lead_time, statistics_per_var, data_origin):
+        data_prep = DefaultDataHandler.build('DEBW107', data_path=os.path.join(path, 'data'),
                                              experiment_path=os.path.join(path, 'exp_path'),
                                              statistics_per_var=statistics_per_var, station_type="background",
-                                             network="AIRBASE", sampling="daily", target_dim="variables",
-                                             target_var="o3", time_dim="datetime",
+                                             sampling="daily", target_dim="variables",
+                                             target_var="o3", time_dim="datetime", data_origin=data_origin,
                                              window_history_size=window_history_size,
                                              window_lead_time=window_lead_time, name_affix="train")
         return DataCollection([data_prep])
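
Note on the convolutional_networks.py hunk: the custom objects registered for model
serialization now reuse the exact loss object stored in self.compile_options instead of
constructing a second custom_loss instance, so the loss keras reloads is the one the
model was actually compiled with. A minimal sketch of the idea; the var_loss and
custom_loss bodies below are simplified stand-ins, not MLAir's implementations:

    import tensorflow as tf
    from tensorflow import keras

    def var_loss(y_true, y_pred):  # stand-in for mlair.model_modules.loss.var_loss
        return tf.math.reduce_variance(y_true - y_pred)

    def custom_loss(loss_list):  # stand-in: sums several loss terms into one callable
        def loss(y_true, y_pred):
            return tf.add_n([l(y_true, y_pred) for l in loss_list])
        return loss

    compile_options = {"loss": [custom_loss([keras.losses.mean_squared_error, var_loss])]}
    # register the very object used for compiling, as the patch does:
    custom_objects = {"loss": compile_options["loss"][0], "var_loss": var_loss}
    # keras.models.load_model("model.h5", custom_objects=custom_objects)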
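
Note on the model_setup.py hunk: the settings serializer previously indexed v[0]
unconditionally, so an empty list raised an IndexError while building the report
DataFrame; the patch inspects v[0] only for non-empty lists and renders "[]" otherwise.
A standalone sketch of the fixed branching; _f below is an identity stand-in for
MLAir's string formatter:

    def _f(s: str) -> str:
        return s  # placeholder for MLAir's real formatter

    def serialize(v):
        if isinstance(v, list):
            if len(v) > 0:
                if isinstance(v[0], dict):
                    v = ["{" + vi + "}" for vi in
                         [",".join(f"{_f(str(uk))}:{_f(str(uv))}" for uk, uv in d.items()) for d in v]]
                else:
                    v = ",".join(_f(str(u)) for u in v)
            else:
                v = "[]"  # previously: IndexError on v[0]
        return str(v)

    assert serialize([]) == "[]"               # crashed before the patch
    assert serialize([1, 2]) == "1,2"
    assert serialize([{"a": 1}]) == "['{a:1}']"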
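
Note on the run_environment.py hunk: __del__ can run on an object whose __init__
aborted before self.time was assigned (CPython still finalizes the instance), which
previously raised an AttributeError during cleanup; the timer block is therefore now
guarded with hasattr. A minimal sketch of the pattern, using a hypothetical class
rather than MLAir's RunEnvironment:

    class TrackedRun:  # hypothetical stand-in for RunEnvironment
        def __init__(self, fail: bool = False):
            if fail:
                raise RuntimeError("aborted before attributes exist")
            self.time = "0:00:01"

        def __del__(self):
            # never assume attributes set late in __init__ are present here
            if hasattr(self, "time"):
                print(f"finished after {self.time}")

    try:
        TrackedRun(fail=True)  # finalizing the partial object no longer raises
    except RuntimeError:
        pass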