diff --git a/mlair/data_handler/data_handler_wrf_chem.py b/mlair/data_handler/data_handler_wrf_chem.py index 03bf50f1bd24f15252867ac52bd633d76865155b..6f1d3566c877407830b22b71919f9850d19c4dac 100644 --- a/mlair/data_handler/data_handler_wrf_chem.py +++ b/mlair/data_handler/data_handler_wrf_chem.py @@ -9,25 +9,32 @@ import dask import dask.array as da import os from mlair.helpers.geofunctions import haversine_dist -from mlair.helpers.helpers import convert2xrda +from mlair.helpers.helpers import convert2xrda, remove_items from typing import Tuple, Union, List, Dict from mlair.data_handler.abstract_data_handler import AbstractDataHandler +import logging import cartopy.crs as ccrs +float_np_xr = Union[float, np.ndarray, xr.DataArray, xr.Dataset] + class WrfChemDataHandler(AbstractDataHandler): - DEFAULT_TIME_DIM = "XTIME" - DEFAULT_RECHUNK = {"XTIME": 1, "y": 36, "x": 40} + DEFAULT_LOGICAL_TIME_COORD_NAME = 'Time' DEFAULT_LOGICAL_X_COORD_NAME = 'x' DEFAULT_LOGICAL_Y_COORD_NAME = 'y' + DEFAULT_LOGICAL_Z_COORD_NAME = 'z' + + DEFAULT_PHYSICAL_TIME_COORD_NAME = "XTIME" DEFAULT_PHYSICAL_X_COORD_NAME = 'XLONG' DEFAULT_PHYSICAL_Y_COORD_NAME = 'XLAT' + DEFAULT_RECHUNK = {"XTIME": 1, "y": 36, "x": 40} - def __init__(self, data_path: str, common_file_starter: str, time_dim_name: str = DEFAULT_TIME_DIM, - rechunk_values: Dict = DEFAULT_RECHUNK, logical_x_coord_name: str = DEFAULT_LOGICAL_X_COORD_NAME, + def __init__(self, data_path: str, common_file_starter: str, time_dim_name: str = DEFAULT_LOGICAL_TIME_COORD_NAME, + rechunk_values: Dict = None, logical_x_coord_name: str = DEFAULT_LOGICAL_X_COORD_NAME, logical_y_coord_name: str = DEFAULT_LOGICAL_Y_COORD_NAME, + logical_z_coord_name: str = DEFAULT_LOGICAL_Z_COORD_NAME, physical_x_coord_name: str = DEFAULT_PHYSICAL_X_COORD_NAME, physical_y_coord_name: str = DEFAULT_PHYSICAL_Y_COORD_NAME ): @@ -35,14 +42,16 @@ class WrfChemDataHandler(AbstractDataHandler): self.data_path = data_path self.common_file_starter = common_file_starter 
self.time_dim_name = time_dim_name - self.rechunk_values = rechunk_values self.logical_x_coord_name = logical_x_coord_name self.logical_y_coord_name = logical_y_coord_name + self.logical_z_coord_name = logical_z_coord_name self.physical_x_coord_name = physical_x_coord_name self.physical_y_coord_name = physical_y_coord_name + # Keep caller-supplied chunking; fall back to one time step per chunk only when none is given. + self.rechunk_values = {self.time_dim_name: 1} if rechunk_values is None else rechunk_values # internal self._X = None @@ -58,11 +67,10 @@ class WrfChemDataHandler(AbstractDataHandler): return os.path.join(self.data_path, self.common_file_starter + '*') def open_data(self): - print(f'open data: {self.dataset_search_str}') - # ds = xr.open_mfdataset(paths=self.dataset_search_str, combine='nested', concat_dim='Time', - # parallel=True, chunks={'Time': 12}) - data = xr.open_mfdataset(paths=self.dataset_search_str, combine='nested', concat_dim='Time', + logging.debug(f'open data: {self.dataset_search_str}') + data = xr.open_mfdataset(paths=self.dataset_search_str, combine='nested', concat_dim=self.time_dim_name, parallel=True) + # data = data.assign_coords({'XTIME': data.XTIME.values}) self._data = data def rechunk_data(self, chunks=None, name_prefix='xarray-', token=None, lock=False): @@ -75,7 +83,6 @@ class WrfChemDataHandler(AbstractDataHandler): for l in range(1, end): combs.extend(itertools.combinations(range(start, end), l)) - # for i in [f'[{m-1}-{m}]' for m in range(2,9)]: for i in combs: filenamestr = f'{self.dataset_search_str[:-1]}{list(i)}_*' try: @@ -118,12 +125,12 @@ class DataHandlerSingleGridCoulumn(WrfChemDataHandler): DEFAULT_ITER_DIM = "points" DEFAULT_WINDOW_DIM = "window" - - def __init__(self, coords: Tuple[float, float], + def __init__(self, coords: Tuple[float_np_xr, float_np_xr], network=DEFAULT_MODEL, target_dim=DEFAULT_TARGET_DIM, target_var=DEFAULT_TARGET_VAR, iter_dim=DEFAULT_ITER_DIM, window_dim=DEFAULT_WINDOW_DIM, window_history_size=DEFAULT_WINDOW_HISTORY_SIZE, window_history_offset=DEFAULT_WINDOW_HISTORY_OFFSET, 
window_lead_time=DEFAULT_WINDOW_LEAD_TIME, + external_coords_file: str = None, transformation=None, store_data_locally: bool = True, min_length: int = 0, start=None, end=None, variables=None, **kwargs): super().__init__(**kwargs) @@ -134,12 +141,23 @@ class DataHandlerSingleGridCoulumn(WrfChemDataHandler): self.window_lead_time = window_lead_time self._nearest_icoords = None self._nearest_coords = None + self.external_coords_file = external_coords_file self.open_data() + if external_coords_file is not None: + self._apply_external_coordinates() self.rechunk_data(self.rechunk_values) self._set_nearest_icoords(dim=[self.logical_x_coord_name, self.logical_y_coord_name]) self._set_nearest_coords() + def _apply_external_coordinates(self): + ds_coords = xr.open_dataset(self.external_coords_file, chunks={'south_north':36, 'west_east':40}) + data = self._data + for k, v in ds_coords.coords.variables.items(): + data = data.assign_coords({k: (remove_items(list(v.dims), 'Time'), v.values.squeeze())}) + self._data = data + logging.debug('setup external coords') + def _set_coords(self, coords): __set_coords = dict(lat=None, lon=None) if len(coords) != 2: @@ -177,8 +195,10 @@ class DataHandlerSingleGridCoulumn(WrfChemDataHandler): icoords = self.get_nearest_icoords() ilat = convert2xrda(np.array(icoords[self.logical_y_coord_name]), use_1d_default=True) ilon = convert2xrda(np.array(icoords[self.logical_x_coord_name]), use_1d_default=True) - lat = self._data[self.physical_y_coord_name].isel({'x': ilon, 'y': ilat}) - lon = self._data[self.physical_x_coord_name].isel({'x': ilon, 'y': ilat}) + lat = self._data[self.physical_y_coord_name].isel( + {self.logical_x_coord_name: ilon, self.logical_y_coord_name: ilat}) + lon = self._data[self.physical_x_coord_name].isel( + {self.logical_x_coord_name: ilon, self.logical_y_coord_name: ilat}) self._nearest_coords = dict(lat=lat, lon=lon) def get_nearest_coords(self, as_arrays=False): @@ -190,10 +210,6 @@ class 
DataHandlerSingleGridCoulumn(WrfChemDataHandler): return {k: list(v.values) for k, v in self._nearest_coords.items()} - - - - if __name__ == '__main__': def plot_map_proj(data, xlim=None, ylim=None, filename=None, point=None): @@ -228,39 +244,59 @@ if __name__ == '__main__': use_first_dummy_dataset = True if use_first_dummy_dataset: - wrf_gridcol = DataHandlerSingleGridCoulumn((lat_xr, lon_xr), - data_path='/home/felix/Data/WRF-Chem/', - common_file_starter='wrfout_d01_2010-', - time_dim_name='Time', - - ) - wrf_gridcol.get_nearest_coords() - wrf_dh = WrfChemDataHandler(data_path='/home/felix/Data/WRF-Chem/', - common_file_starter='wrfout_d01_2010-', - time_dim_name='Time', - ) - wrf_dh.open_data() - wrf_dh.rechunk_data({"XTIME": 1, "y": 36, "x": 40}) - T2 = wrf_dh._data.T2 - - icoords = dask.compute(wrf_dh.compute_nearest_icoordinates(lat_np, lon_np))[0] - - dist_np = wrf_dh.get_distances(lat_np, lon_np) - dist_xr = wrf_dh.get_distances(lat_xr, lon_xr) + wrf_new = DataHandlerSingleGridCoulumn((lat_xr, lon_xr), + data_path='/home/felix/Data/WRF-Chem/upload_aura_2021-02-24/2009/', + common_file_starter='wrfout_d0', + time_dim_name='Time', + logical_x_coord_name='west_east', + logical_y_coord_name='south_north', + logical_z_coord_name='bottom_top', + rechunk_values={'Time': 1, 'bottom_top': 2}, + external_coords_file='/home/felix/Data/WRF-Chem/upload_aura_2021-02-24/coords.nc', + ) + + # wrf_gridcol = DataHandlerSingleGridCoulumn((lat_xr, lon_xr), + # data_path='/home/felix/Data/WRF-Chem/', + # common_file_starter='wrfout_d01_2010-', + # time_dim_name='Time', + # + # ) + # wrf_gridcol.get_nearest_coords() + # wrf_dh = WrfChemDataHandler(data_path='/home/felix/Data/WRF-Chem/', + # common_file_starter='wrfout_d01_2010-', + # time_dim_name='Time', + # ) + # wrf_dh.open_data() + # wrf_dh.rechunk_data({"XTIME": 1, "y": 36, "x": 40}) + # T2 = wrf_dh._data.T2 + + icoords = dask.compute(wrf_new.compute_nearest_icoordinates(lat_np, lon_np))[0] + + dist_np = 
wrf_new.get_distances(lat_np, lon_np) + dist_xr = wrf_new.get_distances(lat_xr, lon_xr) dist_xr.attrs.update(dict(units='km')) dist_xr_set = xr.Dataset({'dist': dist_xr}) - for i, (data, xlim, ylim) in enumerate(((wrf_dh._data.T2, [-42, 66], [23, 80]), + for i, (data, xlim, ylim) in enumerate(((wrf_new._data.T2.isel({'Time': 0}), [-42, 66], [23, 80]), (dist_xr_set.dist, [-42, 66], [23, 80]), - (wrf_dh._data.T2.where(dist_xr.sel({'points': 0}).drop('points') <= 100), [2, 15], [45, 58]), + (wrf_new._data.T2.isel({'Time': 0}).where(dist_xr.sel({'points': 0}).drop('points') <= 100), [2, 15], [45, 58]), (dist_xr_set.dist.where(dist_xr.sel({'points': 0}).drop('points') <= 100), [2, 15], [45, 58]), )): plot_map_proj(data, xlim=xlim, ylim=ylim, - point=[lat_np, lon_np], filename=f'test_dist{i}.pdf') + point=[lat_np, lon_np], filename=f'Example_dist{i}.pdf') + + for i, (data, xlim, ylim) in enumerate(((wrf_new._data.o3.isel({'Time': 0, 'bottom_top':0}), [-42, 66], [23, 80]), + (dist_xr_set.dist, [-42, 66], [23, 80]), + (wrf_new._data.o3.isel({'Time': 0, 'bottom_top': 0}).where(dist_xr.sel({'points': 0}).drop('points') <= 100), [2, 15], [45, 58]), + (dist_xr_set.dist.where(dist_xr.sel({'points': 0}).drop('points') <= 100), [2, 15], [45, 58]), + )): + plot_map_proj(data, xlim=xlim, + ylim=ylim, + point=[lat_np, lon_np], filename=f'ExampleO3_dist{i}.pdf') ######################### # Larger 4D data - use_second_dummy_dataset = True + use_second_dummy_dataset = False if use_second_dummy_dataset: wrf_dh_4d = WrfChemDataHandler(data_path='/home/felix/Data/WRF-Chem/upload_aura/2009/2009', common_file_starter='wrfout_d01_2009',