From 38672cc53a127c05b1e8574ae3181ce4c6cb9099 Mon Sep 17 00:00:00 2001
From: Felix Kleinert <f.kleinert@fz-juelich.de>
Date: Wed, 11 Aug 2021 17:31:03 +0200
Subject: [PATCH] move _force_dask_comp to absract dh and include
 input_output_sampling4toarstats

---
 mlair/data_handler/abstract_data_handler.py | 14 ++++++++++
 mlair/data_handler/data_handler_wrf_chem.py | 17 +++++++-----
 mlair/data_handler/default_data_handler.py  | 13 ---------
 mlair/model_modules/model_class.py          |  2 +-
 run_wrf_dh.py                               | 29 +++++++++++++++------
 5 files changed, 47 insertions(+), 28 deletions(-)

diff --git a/mlair/data_handler/abstract_data_handler.py b/mlair/data_handler/abstract_data_handler.py
index 36d6e9ae..a6f49d2d 100644
--- a/mlair/data_handler/abstract_data_handler.py
+++ b/mlair/data_handler/abstract_data_handler.py
@@ -3,8 +3,12 @@ __author__ = 'Lukas Leufen'
 __date__ = '2020-09-21'
 
 import inspect
+import logging
 from typing import Union, Dict
 
+import dask
+from dask.diagnostics import ProgressBar
+
 from mlair.helpers import remove_items
 
 
@@ -84,3 +88,13 @@ class AbstractDataHandler:
 
     def _hash_list(self):
         return []
+
+    @staticmethod
+    def _force_dask_computation(data):
+        try:
+            with ProgressBar():
+                logging.info(f"DefaultDH: _force_dask_computation")
+                data = dask.compute(data)[0]
+        except:
+            logging.info("can't execute dask.compute")
+        return data
diff --git a/mlair/data_handler/data_handler_wrf_chem.py b/mlair/data_handler/data_handler_wrf_chem.py
index 4aedd310..158a7fb2 100644
--- a/mlair/data_handler/data_handler_wrf_chem.py
+++ b/mlair/data_handler/data_handler_wrf_chem.py
@@ -612,6 +612,7 @@ class DataHandlerSingleGridColumn(DataHandlerSingleStation):
                  date_format_of_nc_file=None,
                  as_image_like_data_format=True,
                  time_zone=None,
+                 input_output_sampling4toarstats : tuple = None,
                  **kwargs):
         self.external_coords_file = external_coords_file
         self.var_logical_z_coord_selector = self._return_z_coord_select_if_valid(var_logical_z_coord_selector,
@@ -627,7 +628,7 @@ class DataHandlerSingleGridColumn(DataHandlerSingleStation):
 
         self.as_image_like_data_format = as_image_like_data_format
         self.time_zone = time_zone
-        self.tzwhere = tzwhere.tzwhere()
+        self.input_output_sampling4toarstats = input_output_sampling4toarstats
         super().__init__(*args, **kwargs)
 
     @staticmethod
@@ -742,19 +743,23 @@ class DataHandlerSingleGridColumn(DataHandlerSingleStation):
                  self.loader.physical_y_coord_name: self.loader.get_coordinates()['lat'].tolist()}
         meta = pd.DataFrame(_meta, index=station)
 
-        if isinstance(self.sampling, tuple) and len(self.sampling) == 2:
+        if isinstance(self.input_output_sampling4toarstats, tuple) and len(self.input_output_sampling4toarstats) == 2:
             if self.var_logical_z_coord_selector != 0:
                 raise NotImplementedError(
                     f"Method `apply_toarstats` is not implemented for var_logical_z_coord_selector != 0: "
                     f"Is {self.var_logical_z_coord_selector}")
-            data = self.apply_toarstats(data, meta, self.sampling[1])
+            data = self.apply_toarstats(data, meta, self.input_output_sampling4toarstats[1])
 
         data = self._slice_prep(data, start=start, end=end)
         return data, meta
 
+    @TimeTrackingWrapper
     def apply_toarstats(self, data, meta, target_sampling="daily"):
-        local_time_zone = self.tzwhere.tzNameAt(latitude=meta[self.loader.physical_y_coord_name],
-                                                longitude=meta[self.loader.physical_x_coord_name])
+        tz_where = tzwhere.tzwhere()
+        local_time_zone = tz_where.tzNameAt(latitude=meta[self.loader.physical_y_coord_name],
+                                            longitude=meta[self.loader.physical_x_coord_name])
+        # local_time_zone = self.tzwhere.tzNameAt(latitude=meta[self.loader.physical_y_coord_name],
+        #                                         longitude=meta[self.loader.physical_x_coord_name])
         hdata = data.squeeze().to_pandas()
         hdata.index = self.set_time_zone(hdata.index)
         hdata.index = hdata.index.tz_convert(local_time_zone)
@@ -774,7 +779,7 @@ class DataHandlerSingleGridColumn(DataHandlerSingleStation):
         for coord in missing_squeezed_coords:
             sampling_data.coords[coord] = data.coords[coord]
 
-        return self. sampling_data
+        return self._force_dask_computation(sampling_data)
 
     def set_time_zone(self, time_index):
         """
diff --git a/mlair/data_handler/default_data_handler.py b/mlair/data_handler/default_data_handler.py
index 18828370..1b4715a0 100644
--- a/mlair/data_handler/default_data_handler.py
+++ b/mlair/data_handler/default_data_handler.py
@@ -7,15 +7,12 @@ import inspect
 import gc
 import logging
 import os
-import pickle
 import dill
 import shutil
 from functools import reduce
 from typing import Tuple, Union, List
 import multiprocessing
 import psutil
-import dask
-from dask.diagnostics import ProgressBar
 
 import numpy as np
 import xarray as xr
@@ -111,16 +108,6 @@ class DefaultDataHandler:
             attr_dict[attr] = val
         return attr_dict
 
-    @staticmethod
-    def _force_dask_computation(data):
-        try:
-            with ProgressBar():
-                logging.info(f"DefaultDH: _force_dask_computation")
-                data = dask.compute(data)[0]
-        except:
-            logging.info("can't execute dask.compute")
-        return data
-
     def _load(self):
         try:
             with open(self._save_file, "rb") as f:
diff --git a/mlair/model_modules/model_class.py b/mlair/model_modules/model_class.py
index 8f464b16..fd43674e 100644
--- a/mlair/model_modules/model_class.py
+++ b/mlair/model_modules/model_class.py
@@ -518,7 +518,7 @@ class MyLuongAttentionLSTMModel(AbstractModelClass):
 
         self.initial_lr = 0.01
         self.clipnorm = 1
-        self.n_hidden = 100
+        self.n_hidden = 32
 
         # apply to model
         self.set_model()
diff --git a/run_wrf_dh.py b/run_wrf_dh.py
index a406f193..bd5aa802 100644
--- a/run_wrf_dh.py
+++ b/run_wrf_dh.py
@@ -8,7 +8,7 @@ from mlair.workflows import DefaultWorkflow
 from mlair.helpers import remove_items
 from mlair.configuration.defaults import DEFAULT_PLOT_LIST
 
-from mlair.model_modules.model_class import MyLSTMModel, MyCNNModel
+from mlair.model_modules.model_class import MyLSTMModel, MyCNNModel, MyLuongAttentionLSTMModel
 
 import os
 
@@ -42,7 +42,7 @@ def main(parser_args):
         # data_handler=DataHandlerSingleStation,
         # data_handler=DataHandlerSingleGridColumn,
         upsampeling=False,
-        epochs=20,
+        epochs=200,
         window_lead_time=4,
         window_history_size=6,
         # ('Germany', (5.98865807458, 47.3024876979, 15.0169958839, 54.983104153))
@@ -90,7 +90,7 @@ def main(parser_args):
             'no2': {"method": "standardise"},
             'co': {"method": "standardise"},
             'PSFC': {"method": "standardise"},
-            'CLDFRA': {"method": "min_max"},
+            'CLDFRA': {"method": "min_max", "min": 0., "max": 1.},
         },
         variables=['T2', 'o3', 'wdir10ll', 'wspd10ll', 'no', 'no2', 'co', 'PSFC', 'PBLH', 'CLDFRA'],
         target_var='o3',
@@ -120,6 +120,8 @@ def main(parser_args):
         # train_end='2009-01-02',
         # train_start='2009-01-01',
         #train_end='2009-01-15',
+        # train_end='2009-02-15',
+        # train_end="2009-01-31",
         train_end='2009-02-28',
 
         # val_start='2009-01-02',
@@ -130,26 +132,37 @@ def main(parser_args):
         ###################################
         val_start='2009-03-01',
         val_end='2009-03-14',
+        ###################################
+        # val_start='2009-02-15',
+        # val_end='2009-03-02',
+        ###################################
+        # val_start='2009-02-01',
+        # val_end='2009-02-28',
 
         # test_start='2009-01-03',
         # test_end='2009-01-04',
         ###################################
-        #test_start='2009-01-22',
-        #test_end='2009-01-31',
+        # test_start='2009-01-22',
+        # test_end='2009-01-31',
         ###################################
         test_start='2009-03-15',
         test_end='2009-03-31',
+        ###################################
+        # test_start='2009-03-02',
+        # test_end='2009-03-31',
 
         # sampling='hourly',
-        sampling=("hourly", "daily"),
+        sampling="daily",
+        input_output_sampling4toarstats=("hourly", "daily"),
         interpolation_limit=0,
-        # as_image_like_data_format=False,
+        as_image_like_data_format=False,
         # model=MyLSTMModel,
+        model=MyLuongAttentionLSTMModel,
        use_multiprocessing=True,
 
         # as_image_like_data_format=True,
         # model=MyLSTMModel,
-        model=MyCNNModel,
+        # model=MyCNNModel,
 
         **parser_args.__dict__)
     workflow.run()
--
GitLab
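
Note on reuse: with `_force_dask_computation` moved onto `AbstractDataHandler` as a @staticmethod, any data handler (not only `DefaultDataHandler`) can force eager evaluation of lazy dask objects through the same code path, which is what `apply_toarstats` now does with its resampled result. The snippet below is only an illustrative sketch and not part of the patch; the toy dask array stands in for whatever lazy object a real handler would produce.

    # Hypothetical usage sketch (not part of the patch).
    import dask.array as da

    from mlair.data_handler.abstract_data_handler import AbstractDataHandler

    # A small lazy dask array stands in for the lazily resampled data that
    # apply_toarstats() returns in the WRF-Chem handler.
    lazy = da.ones((4, 3), chunks=(2, 3)).mean(axis=0)

    # Shows a dask ProgressBar while computing; if dask.compute() raises,
    # the helper only logs a message and returns the input unchanged.
    result = AbstractDataHandler._force_dask_computation(lazy)
    print(result)  # -> [1. 1. 1.]

Splitting the former sampling=("hourly", "daily") tuple into sampling="daily" plus the new input_output_sampling4toarstats=("hourly", "daily") keyword keeps `sampling` a plain string, while the hourly-to-daily toarstats aggregation in `apply_toarstats` is requested explicitly via the second element of the new tuple.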