diff --git a/mlair/run_modules/pre_processing.py b/mlair/run_modules/pre_processing.py
index 8443b10d4d16b71819b795c0579b4d61cb739b70..0e416acbca4d66d5844e1179c7653ac5a9934f28 100644
--- a/mlair/run_modules/pre_processing.py
+++ b/mlair/run_modules/pre_processing.py
@@ -114,17 +114,17 @@ class PreProcessing(RunEnvironment):
         +------------+-------------------------------------------+---------------+---------------+---------------+---------+-------+--------+
 
         """
-        meta_data = ['station_name', 'station_lon', 'station_lat', 'station_alt']
+        meta_cols = ['station_name', 'station_lon', 'station_lat', 'station_alt']
         meta_round = ["station_lon", "station_lat", "station_alt"]
         precision = 4
         path = os.path.join(self.data_store.get("experiment_path"), "latex_report")
         path_config.check_path_and_create(path)
         names_of_set = ["train", "val", "test"]
-        df = self.create_info_df(meta_data, meta_round, names_of_set, precision)
+        df = self.create_info_df(meta_cols, meta_round, names_of_set, precision)
         column_format = tables.create_column_format_for_tex(df)
         tables.save_to_tex(path=path, filename="station_sample_size.tex", column_format=column_format, df=df)
         tables.save_to_md(path=path, filename="station_sample_size.md", df=df)
-        df_nometa = df.drop(meta_data, axis=1)
+        df_nometa = df.drop(meta_cols, axis=1)
         column_format = tables.create_column_format_for_tex(df)
         tables.save_to_tex(path=path, filename="station_sample_size_short.tex", column_format=column_format,
                            df=df_nometa)
@@ -150,15 +150,49 @@ class PreProcessing(RunEnvironment):
         df_descr = df_descr[df_descr_colnames]
         return df_descr
 
-    def create_info_df(self, meta_data, meta_round, names_of_set, precision):
-        df = pd.DataFrame(columns=meta_data + names_of_set)
+    def create_info_df(self, meta_cols, meta_round, names_of_set, precision):
+        """
+        Collect the sample count and the meta data of every station for each of the given subsets.
+
+        Per subset the stations are evaluated in a process pool if multiprocessing is enabled in the data store
+        and more than one process can be used, otherwise one after another.
+
+        :param meta_cols: names of the meta entries read from each station, used as leading columns
+        :param meta_round: NOTE(review): not used in the per-station loop, presumably the meta columns to round
+        :param names_of_set: names of the subsets fetched from the data store, one column per subset
+        :param precision: NOTE(review): not used in the per-station loop, presumably the rounding precision
+        """
+        use_multiprocessing = self.data_store.get("use_multiprocessing")
+        max_process = self.data_store.get("max_number_multiprocessing")
+        df = pd.DataFrame(columns=meta_cols + names_of_set)
         for set_name in names_of_set:
             data = self.data_store.get("data_collection", set_name)
-            for station in data:
-                station_name = str(station.id_class)
-                df.loc[station_name, set_name] = station.get_Y()[0].shape[0]
-                if df.loc[station_name, meta_data].isnull().any():
-                    df.loc[station_name, meta_data] = station.id_class.meta.loc[meta_data].values.flatten()
+            # psutil.cpu_count may return None if the number of physical cores is undetermined, fall back to 1
+            n_process = min([psutil.cpu_count(logical=False) or 1, len(data), max_process])  # use only physical cpus
+            if n_process > 1 and use_multiprocessing is True:  # parallel solution
+                logging.info(f"use parallel create_info_df ({set_name})")
+                pool = multiprocessing.Pool(n_process)
+                try:
+                    logging.info(f"running {getattr(pool, '_processes')} processes in parallel")
+                    output = [pool.apply_async(f_proc_create_info_df, args=(station, meta_cols)) for station in data]
+                    for i, p in enumerate(output):
+                        res = p.get()
+                        station_name, shape, meta = res["station_name"], res["Y_shape"], res["meta"]
+                        df.loc[station_name, set_name] = shape
+                        if df.loc[station_name, meta_cols].isnull().any():
+                            df.loc[station_name, meta_cols] = meta
+                        logging.info(f"...finished: {station_name} ({int((i + 1.) / len(output) * 100)}%)")
+                finally:  # always release the workers, also if a task raised in p.get()
+                    pool.close()
+                    pool.join()
+            else:  # serial solution
+                logging.info(f"use serial create_info_df ({set_name})")
+                for station in data:
+                    res = f_proc_create_info_df(station, meta_cols)
+                    station_name, shape, meta = res["station_name"], res["Y_shape"], res["meta"]
+                    df.loc[station_name, set_name] = shape
+                    if df.loc[station_name, meta_cols].isnull().any():
+                        df.loc[station_name, meta_cols] = meta
             df.loc["# Samples", set_name] = df.loc[:, set_name].sum()
             assert len(data) == df.loc[:, set_name].count() - 1
             df.loc["# Stations", set_name] = len(data)
@@ -380,6 +414,23 @@ def f_proc(data_handler, station, name_affix, store, return_strategy="", tmp_pat
         return _tmp_file, station
 
 
+def f_proc_create_info_df(data, meta_cols):
+    """
+    Extract the station name, the number of samples and the requested meta data from a single data handler.
+
+    Module-level function, used as task for the process pool and in the serial loop of create_info_df.
+
+    :param data: data handler of a single station, must provide id_class (with meta) and get_Y()
+    :param meta_cols: names of the meta entries to extract
+    :return: dictionary with keys "station_name", "Y_shape" (first dimension of the first element of get_Y()) and
+        "meta" (flattened values of the selected meta entries)
+    """
+    station_name = str(data.id_class)
+    res = {"station_name": station_name, "Y_shape": data.get_Y()[0].shape[0],
+           "meta": data.id_class.meta.loc[meta_cols].values.flatten()}
+    return res
+
+
 def f_inspect_error(formatted):
     for i in range(len(formatted) - 1, -1, -1):
         if "mlair/mlair" not in formatted[i]: