diff --git a/mlair/run_modules/pre_processing.py b/mlair/run_modules/pre_processing.py
index 8443b10d4d16b71819b795c0579b4d61cb739b70..0e416acbca4d66d5844e1179c7653ac5a9934f28 100644
--- a/mlair/run_modules/pre_processing.py
+++ b/mlair/run_modules/pre_processing.py
@@ -114,17 +114,17 @@ class PreProcessing(RunEnvironment):
             +------------+-------------------------------------------+---------------+---------------+---------------+---------+-------+--------+
 
         """
-        meta_data = ['station_name', 'station_lon', 'station_lat', 'station_alt']
+        meta_cols = ['station_name', 'station_lon', 'station_lat', 'station_alt']
         meta_round = ["station_lon", "station_lat", "station_alt"]
         precision = 4
         path = os.path.join(self.data_store.get("experiment_path"), "latex_report")
         path_config.check_path_and_create(path)
         names_of_set = ["train", "val", "test"]
-        df = self.create_info_df(meta_data, meta_round, names_of_set, precision)
+        df = self.create_info_df(meta_cols, meta_round, names_of_set, precision)
         column_format = tables.create_column_format_for_tex(df)
         tables.save_to_tex(path=path, filename="station_sample_size.tex", column_format=column_format, df=df)
         tables.save_to_md(path=path, filename="station_sample_size.md", df=df)
-        df_nometa = df.drop(meta_data, axis=1)
+        df_nometa = df.drop(meta_cols, axis=1)
         column_format = tables.create_column_format_for_tex(df)
         tables.save_to_tex(path=path, filename="station_sample_size_short.tex", column_format=column_format,
                            df=df_nometa)
@@ -150,15 +150,35 @@ class PreProcessing(RunEnvironment):
         df_descr = df_descr[df_descr_colnames]
         return df_descr
 
-    def create_info_df(self, meta_data, meta_round, names_of_set, precision):
-        df = pd.DataFrame(columns=meta_data + names_of_set)
+    def create_info_df(self, meta_cols, meta_round, names_of_set, precision):
+        use_multiprocessing = self.data_store.get("use_multiprocessing")
+        max_process = self.data_store.get("max_number_multiprocessing")
+        df = pd.DataFrame(columns=meta_cols + names_of_set)
         for set_name in names_of_set:
             data = self.data_store.get("data_collection", set_name)
-            for station in data:
-                station_name = str(station.id_class)
-                df.loc[station_name, set_name] = station.get_Y()[0].shape[0]
-                if df.loc[station_name, meta_data].isnull().any():
-                    df.loc[station_name, meta_data] = station.id_class.meta.loc[meta_data].values.flatten()
+            n_process = min([psutil.cpu_count(logical=False), len(data), max_process])  # use only physical cpus
+            if n_process > 1 and use_multiprocessing is True:  # parallel solution
+                logging.info(f"use parallel create_info_df ({set_name})")
+                pool = multiprocessing.Pool(n_process)
+                logging.info(f"running {getattr(pool, '_processes')} processes in parallel")
+                output = [pool.apply_async(f_proc_create_info_df, args=(station, meta_cols)) for station in data]
+                for i, p in enumerate(output):
+                    res = p.get()
+                    station_name, shape, meta = res["station_name"], res["Y_shape"], res["meta"]
+                    df.loc[station_name, set_name] = shape
+                    if df.loc[station_name, meta_cols].isnull().any():
+                        df.loc[station_name, meta_cols] = meta
+                    logging.info(f"...finished: {station_name} ({int((i + 1.) / len(output) * 100)}%)")
+                pool.close()
+                pool.join()
+            else:  # serial solution
+                logging.info(f"use serial create_info_df ({set_name})")
+                for station in data:
+                    res = f_proc_create_info_df(station, meta_cols)
+                    station_name, shape, meta = res["station_name"], res["Y_shape"], res["meta"]
+                    df.loc[station_name, set_name] = shape
+                    if df.loc[station_name, meta_cols].isnull().any():
+                        df.loc[station_name, meta_cols] = meta
             df.loc["# Samples", set_name] = df.loc[:, set_name].sum()
             assert len(data) == df.loc[:, set_name].count() - 1
             df.loc["# Stations", set_name] = len(data)
@@ -380,6 +400,13 @@ def f_proc(data_handler, station, name_affix, store, return_strategy="", tmp_pat
     return _tmp_file, station
 
 
+def f_proc_create_info_df(data, meta_cols):
+    station_name = str(data.id_class)
+    res = {"station_name": station_name, "Y_shape": data.get_Y()[0].shape[0],
+           "meta": data.id_class.meta.loc[meta_cols].values.flatten()}
+    return res
+
+
 def f_inspect_error(formatted):
     for i in range(len(formatted) - 1, -1, -1):
         if "mlair/mlair" not in formatted[i]:
diff --git a/test/test_run_modules/test_pre_processing.py b/test/test_run_modules/test_pre_processing.py
index 0f2ee7a10fd2e3190c0b66da558626747d4c03c9..1dafdbd5c4882932e3d57e726e7a06bea22a745d 100644
--- a/test/test_run_modules/test_pre_processing.py
+++ b/test/test_run_modules/test_pre_processing.py
@@ -46,12 +46,14 @@ class TestPreProcessing:
         with PreProcessing():
             assert caplog.record_tuples[0] == ('root', 20, 'PreProcessing started')
             assert caplog.record_tuples[1] == ('root', 20, 'check valid stations started (preprocessing)')
-            assert caplog.record_tuples[-3] == ('root', 20, PyTestRegex(r'run for \d+:\d+:\d+ \(hh:mm:ss\) to check 5 '
+            assert caplog.record_tuples[-6] == ('root', 20, PyTestRegex(r'run for \d+:\d+:\d+ \(hh:mm:ss\) to check 5 '
                                                                         r'station\(s\). Found 5/5 valid stations.'))
+            assert caplog.record_tuples[-5] == ('root', 20, "use serial create_info_df (train)")
+            assert caplog.record_tuples[-4] == ('root', 20, "use serial create_info_df (val)")
+            assert caplog.record_tuples[-3] == ('root', 20, "use serial create_info_df (test)")
             assert caplog.record_tuples[-2] == ('root', 20, "Searching for competitors to be prepared for use.")
-            assert caplog.record_tuples[-1] == (
-                'root', 20, "No preparation required because no competitor was provided "
-                            "to the workflow.")
+            assert caplog.record_tuples[-1] == ('root', 20, "No preparation required because no competitor was provided"
+                                                            " to the workflow.")
         RunEnvironment().__del__()
 
     def test_run(self, obj_with_exp_setup):