diff --git a/mlair/run_modules/pre_processing.py b/mlair/run_modules/pre_processing.py index 4cee4a9744f33c86e8802aad27125cf0e0b30f3a..1a0f8b8614d767b75433d480a37d0cc518e4885c 100644 --- a/mlair/run_modules/pre_processing.py +++ b/mlair/run_modules/pre_processing.py @@ -6,6 +6,7 @@ __date__ = '2019-11-25' import logging import os from typing import Tuple +import multiprocessing import numpy as np import pandas as pd @@ -201,6 +202,42 @@ class PreProcessing(RunEnvironment): Valid means, that there is data available for the given time range (is included in `kwargs`). The shape and the loading time are logged in debug mode. + :return: Corrected list containing only valid station IDs. + """ + t_outer = TimeTracking() + logging.info(f"check valid stations started{' (%s)' % (set_name if set_name is not None else 'all')}") + # calculate transformation using train data + if set_name == "train": + logging.info("setup transformation using train data exclusively") + self.transformation(data_handler, set_stations) + # start station check + collection = DataCollection() + valid_stations = [] + kwargs = self.data_store.create_args_dict(data_handler.requirements(), scope=set_name) + + logging.info("-------start parallel loop------") + pool = multiprocessing.Pool(4) + output = [pool.apply_async(f_proc, args=(data_handler, station, set_name, store_processed_data), kwds=kwargs) + for station in set_stations] + + for p in output: + dh, s = p.get() + if dh is not None: + collection.add(dh) + valid_stations.append(s) + + logging.info(f"run for {t_outer} to check {len(set_stations)} station(s). Found {len(collection)}/" + f"{len(set_stations)} valid stations.") + return collection, valid_stations + + def validate_station_old(self, data_handler: AbstractDataHandler, set_stations, set_name=None, + store_processed_data=True): + """ + Check if all given stations in `all_stations` are valid. + + Valid means, that there is data available for the given time range (is included in `kwargs`). The shape and the + loading time are logged in debug mode. + :return: Corrected list containing only valid station IDs. """ t_outer = TimeTracking() @@ -231,3 +268,12 @@ class PreProcessing(RunEnvironment): transformation_dict = data_handler.transformation(stations, **kwargs) if transformation_dict is not None: self.data_store.set("transformation", transformation_dict) + + +def f_proc(data_handler, station, name_affix, store, **kwargs): + try: + res = data_handler.build(station, name_affix=name_affix, store_processed_data=store, + **kwargs) + except (AttributeError, EmptyQueryResult): + res = None + return res, station