diff --git a/mlair/data_handler/data_handler_wrf_chem.py b/mlair/data_handler/data_handler_wrf_chem.py index 59348b30413c87ec3838af3ea333a294e98cc7f3..ec405db9ff94380688eb7885d65160b8678e0d2a 100644 --- a/mlair/data_handler/data_handler_wrf_chem.py +++ b/mlair/data_handler/data_handler_wrf_chem.py @@ -6,6 +6,7 @@ import numpy as np import itertools import matplotlib.pyplot as plt import dask +import inspect import dask.array as da import os from mlair.helpers.geofunctions import haversine_dist @@ -23,17 +24,19 @@ float_np_xr = Union[float, np.ndarray, xr.DataArray, xr.Dataset] class BaseWrfChemDataLoader: DEFAULT_LOGICAL_TIME_COORD_NAME = 'Time' - DEFAULT_LOGICAL_X_COORD_NAME = 'x' - DEFAULT_LOGICAL_Y_COORD_NAME = 'y' - DEFAULT_LOGICAL_Z_COORD_NAME = 'z' + DEFAULT_LOGICAL_X_COORD_NAME = 'west_east' + DEFAULT_LOGICAL_Y_COORD_NAME = 'south_north' + DEFAULT_LOGICAL_Z_COORD_NAME = 'bottom_top' DEFAULT_PHYSICAL_TIME_COORD_NAME = "XTIME" DEFAULT_PHYSICAL_X_COORD_NAME = 'XLONG' DEFAULT_PHYSICAL_Y_COORD_NAME = 'XLAT' - DEFAULT_RECHUNK = {"XTIME": 1, "y": 36, "x": 40} + DEFAULT_RECHUNK = {"Time": -1, "y": 36, "x": 40} + DEFAULT_FILE_STARTER = 'wrfout_d0' - def __init__(self, data_path: str, common_file_starter: str, time_dim_name: str = DEFAULT_LOGICAL_TIME_COORD_NAME, + def __init__(self, data_path: str, common_file_starter: str = DEFAULT_FILE_STARTER, + time_dim_name: str = DEFAULT_LOGICAL_TIME_COORD_NAME, rechunk_values: Dict = None, logical_x_coord_name: str = DEFAULT_LOGICAL_X_COORD_NAME, logical_y_coord_name: str = DEFAULT_LOGICAL_Y_COORD_NAME, @@ -125,6 +128,15 @@ class BaseWrfChemDataLoader: else: return dist.argmin(dim) + def _set_dims_as_coords(self): + if self._data is None: + raise IOError(f'{self.__class__.__name__} can not set dims as coords. 
Use must use `open_data()` before.') + data = self._data + for k, _ in data.dims.items(): + data = data.assign_coords({k: data[k]}) + self._data = data + logging.info('set dimensions as coordinates') + class SingleGridColumnWrfChemDataLoader(BaseWrfChemDataLoader): DEFAULT_MODEL = "WRF-Chem" @@ -160,6 +172,8 @@ class SingleGridColumnWrfChemDataLoader(BaseWrfChemDataLoader): self.open_data() self.assign_coords( {self.physical_t_coord_name: (self.time_dim_name, self._data[self.physical_t_coord_name].values)}) + + self._set_dims_as_coords() if external_coords_file is not None: self._apply_external_coordinates() self.rechunk_data(self.rechunk_values) @@ -167,12 +181,12 @@ class SingleGridColumnWrfChemDataLoader(BaseWrfChemDataLoader): self._set_nearest_coords() def _apply_external_coordinates(self): - ds_coords = xr.open_dataset(self.external_coords_file, chunks={'south_north':36, 'west_east':40}) + ds_coords = xr.open_dataset(self.external_coords_file, chunks={'south_north': 36, 'west_east': 40}) data = self._data for k, v in ds_coords.coords.variables.items(): data = data.assign_coords({k: (remove_items(list(v.dims), 'Time'), v.values.squeeze())}) self._data = data - print('setup external coords') + logging.info('setup external coords') def _set_coords(self, coords): __set_coords = dict(lat=None, lon=None) @@ -226,60 +240,75 @@ class SingleGridColumnWrfChemDataLoader(BaseWrfChemDataLoader): return {k: list(v.values) for k, v in self._nearest_coords.items()} -class DataHandlerSingleGridCoulumn2(SingleGridColumnWrfChemDataLoader, DataHandlerSingleStation): +class DataHandlerSingleGridColumn2(DataHandlerSingleStation): + _requirements = remove_items(inspect.getfullargspec(DataHandlerSingleStation).args, ["self", "station"]) - def __init__(self, common_file_starter, wrf_kwargs=None, **kwargs): - super().__init__() + def __init__(self, *args, external_coords_file, **kwargs): + self.external_coords_file = external_coords_file + super().__init__(*args, **kwargs) + + 
@staticmethod
def str2coords(str_coords: str, sep='-') -> Tuple[np.ndarray, np.ndarray]:
    """
    Convert a coordinate string like ``'<lat><sep><lon>'`` into a pair of 0-d numpy arrays.

    A list holding exactly one string is accepted as well; its single element is used. If a plain split on
    ``sep`` does not yield exactly two parts (e.g. negative values together with the default separator ``'-'``,
    as in ``'-33.5-151.25'`` or ``'48.0--122.5'``), the string is split at the first separator that is not the
    leading character, so that a leading sign stays attached to its number.

    :param str_coords: coordinate string (or one-element list containing it)
    :param sep: separator between latitude and longitude
    :return: latitude and longitude, each as 0-d ``np.ndarray`` of type float
    :raises ValueError: if latitude and longitude can not be extracted or converted to float
    """
    if isinstance(str_coords, list) and len(str_coords) == 1:
        str_coords = str_coords[0]
    parts = str_coords.split(sep=sep)
    if len(parts) != 2 and sep:
        # Sign-aware fallback: start searching behind the first character to skip a leading sign of the
        # latitude; everything after the separator (including a possible sign) belongs to the longitude.
        text = str_coords.strip()
        split_at = text.find(sep, 1)
        if split_at == -1:
            raise ValueError(f"Can not extract lat and lon from '{str_coords}' using separator '{sep}'.")
        parts = [text[:split_at], text[split_at + len(sep):]]
    lat, lon = parts
    return np.array(float(lat)), np.array(float(lon))


def setup_data_path(self, data_path: str, sampling: str):
    """Return ``data_path`` unchanged; ``sampling`` is not used."""
    return data_path


def load_data(self, path, station, statistics_per_var, sampling, station_type=None, network=None,
              store_data_locally=False, data_origin: Dict = None, start=None, end=None):
    """
    Load the grid column nearest to the coordinates encoded in ``station``.

    The coordinates are parsed with ``str2coords`` and handed to a ``SingleGridColumnWrfChemDataLoader`` reading
    from ``path``. From the loader's data the nearest grid indices are selected, reduced to ``self.variables``,
    extended by the ``self.iter_dim`` dimension and converted to an array ordered as
    (``iter_dim``, ``time_dim``, 'variable', logical z coordinate).

    Only ``path`` and ``station`` are evaluated; all remaining parameters are currently ignored.

    :param path: data path passed to the loader
    :param station: coordinate string (or one-element list) in the format understood by ``str2coords``
    :return: the data re-chunked to a single chunk along ``self.time_dim`` and the meta data (currently None)
    """
    lat, lon = self.str2coords(station)
    sgc_loader = SingleGridColumnWrfChemDataLoader((lat, lon),
                                                   data_path=path,
                                                   rechunk_values={'Time': 1, 'bottom_top': 2},
                                                   external_coords_file=self.external_coords_file,
                                                   )

    # ``str2coords`` accepts a bare string as well as a one-element list. Normalise to a list so the new
    # dimension always receives a sequence of labels.
    station_labels = [station] if isinstance(station, str) else station

    data = sgc_loader.data.isel(sgc_loader.get_nearest_icoords()).squeeze()[self.variables]
    data = data.expand_dims({self.iter_dim: station_labels}).to_array()
    data = data.transpose(self.iter_dim, self.time_dim, 'variable', sgc_loader.logical_z_coord_name)
    # ToDo
    # data should somehow look like this:
    # < xarray.DataArray(Stations: 1, datetime: 7670, variables: 9) (From DataHandlerSingleStation)
    # ToDo: provide meta data
    meta = None
    return data.chunk({self.time_dim: -1}), meta
+ # # self.input_data = inputs + # # self.target_data = targets + # raise NotImplementedError + + # def make_samples(self): + # # self.make_history_window(self.target_dim, self.window_history_size, self.time_dim) + # # self.make_labels(self.target_dim, self.target_var, self.time_dim, self.window_lead_time) + # # self.make_observation(self.target_dim, self.target_var, self.time_dim) + # # self.remove_nan(self.time_dim) + # raise NotImplementedError + + # @TimeTrackingWrapper + # def setup_samples(self): + # """ + # Setup samples. This method prepares and creates samples X, and labels Y. + # """ + # data, self.meta = self.load_data(self.path, self.station, self.statistics_per_var, self.sampling, + # self.station_type, self.network, self.store_data_locally, self.data_origin, + # self.start, self.end) + # self._data = self.interpolate(data, dim=self.time_dim, method=self.interpolation_method, + # limit=self.interpolation_limit) + # # self.set_inputs_and_targets() + # # if self.do_transformation is True: + # # self.call_transform() + # # self.make_samples() + # raise NotImplementedError if __name__ == '__main__': @@ -351,18 +380,24 @@ if __name__ == '__main__': for i, (data, xlim, ylim) in enumerate(((wrf_new._data.T2.isel({'Time': 0}), [-42, 66], [23, 80]), (dist_xr_set.dist, [-42, 66], [23, 80]), - (wrf_new._data.T2.isel({'Time': 0}).where(dist_xr.sel({'points': 0}).drop('points') <= 100), [2, 15], [45, 58]), - (dist_xr_set.dist.where(dist_xr.sel({'points': 0}).drop('points') <= 100), [2, 15], [45, 58]), + (wrf_new._data.T2.isel({'Time': 0}).where( + dist_xr.sel({'points': 0}).drop('points') <= 100), [2, 15], + [45, 58]), + (dist_xr_set.dist.where( + dist_xr.sel({'points': 0}).drop('points') <= 100), [2, 15], + [45, 58]), )): plot_map_proj(data, xlim=xlim, ylim=ylim, point=[lat_np, lon_np], filename=f'Example_dist{i}.pdf') - for i, (data, xlim, ylim) in enumerate(((wrf_new._data.o3.isel({'Time': 0, 'bottom_top':0}), [-42, 66], [23, 80]), - (dist_xr_set.dist, [-42, 
66], [23, 80]), - (wrf_new._data.o3.isel({'Time': 0, 'bottom_top': 0}).where(dist_xr.sel({'points': 0}).drop('points') <= 100), [2, 15], [45, 58]), - (dist_xr_set.dist.where(dist_xr.sel({'points': 0}).drop('points') <= 100), [2, 15], [45, 58]), - )): + for i, (data, xlim, ylim) in enumerate( + ((wrf_new._data.o3.isel({'Time': 0, 'bottom_top': 0}), [-42, 66], [23, 80]), + (dist_xr_set.dist, [-42, 66], [23, 80]), + (wrf_new._data.o3.isel({'Time': 0, 'bottom_top': 0}).where( + dist_xr.sel({'points': 0}).drop('points') <= 100), [2, 15], [45, 58]), + (dist_xr_set.dist.where(dist_xr.sel({'points': 0}).drop('points') <= 100), [2, 15], [45, 58]), + )): plot_map_proj(data, xlim=xlim, ylim=ylim, point=[lat_np, lon_np], filename=f'ExampleO3_dist{i}.pdf') @@ -371,9 +406,7 @@ if __name__ == '__main__': use_second_dummy_dataset = False if use_second_dummy_dataset: wrf_dh_4d = BaseWrfChemDataLoader(data_path='/home/felix/Data/WRF-Chem/upload_aura/2009/2009', - common_file_starter='wrfout_d01_2009', - time_dim_name='Time', - ) + common_file_starter='wrfout_d01_2009', time_dim_name='Time') wrf_dh_4d.open_data() wrf_dh_4d.rechunk_data({"Time": 1, "bottom_top": 34, "south_north": 36, "west_east": 40}) lat_np = np.array([50.73333])