diff --git a/mlair/data_handler/data_handler_wrf_chem.py b/mlair/data_handler/data_handler_wrf_chem.py index 0d625948da27f18ca70fe4ac4f4ec3dc1ebb0b88..ae57c93f9d303fec4df3ba7c5dd051e064374d85 100644 --- a/mlair/data_handler/data_handler_wrf_chem.py +++ b/mlair/data_handler/data_handler_wrf_chem.py @@ -7,34 +7,54 @@ import itertools import matplotlib.pyplot as plt import dask import dask.array as da +import os from mlair.helpers.geofunctions import haversine_dist +from typing import Tuple, Union, List, Dict from mlair.data_handler.abstract_data_handler import AbstractDataHandler import cartopy.crs as ccrs class WrfChemDataHandler(AbstractDataHandler): - - def __init__(self, data_path, common_file_starter, target_dim=None, target_var=None, time_dim=None, - window_history_size=None, window_lead_time=None, + DEFAULT_TIME_DIM = "XTIME" + DEFAULT_RECHUNK = {"XTIME": 1, "y": 36, "x": 40} + DEFAULT_LOGICAL_X_COORD_NAME = 'x' + DEFAULT_LOGICAL_Y_COORD_NAME = 'y' + DEFAULT_PHYSICAL_X_COORD_NAME = 'XLONG' + DEFAULT_PHYSICAL_Y_COORD_NAME = 'XLAT' + + + def __init__(self, data_path: str, common_file_starter: str, time_dim_name: str = DEFAULT_TIME_DIM, + rechunk_values: Dict = DEFAULT_RECHUNK, logical_x_coord_name: str = DEFAULT_LOGICAL_X_COORD_NAME, + logical_y_coord_name: str = DEFAULT_LOGICAL_Y_COORD_NAME, + physical_x_coord_name: str = DEFAULT_PHYSICAL_X_COORD_NAME, + physical_y_coord_name: str = DEFAULT_PHYSICAL_Y_COORD_NAME ): super().__init__() self.data_path = data_path self.common_file_starter = common_file_starter - self.target_dim = target_dim - self.target_var = target_var - self.time_dim = time_dim - self.window_history_size = window_history_size - self.window_lead_time = window_lead_time + self.time_dim_name = time_dim_name + self.rechunk_values = rechunk_values + + self.logical_x_coord_name = logical_x_coord_name + self.logical_y_coord_name = logical_y_coord_name + + self.physical_x_coord_name = physical_x_coord_name + self.physical_y_coord_name = physical_y_coord_name + # internal self._X = None self._Y = None self._data = None + @property + def data(self) -> xr.Dataset: + return self._data + @property def dataset_search_str(self): - return self.data_path + self.common_file_starter + '*' + return os.path.join(self.data_path, self.common_file_starter + '*') def open_data(self): print(f'open data: {self.dataset_search_str}') @@ -75,7 +95,7 @@ class WrfChemDataHandler(AbstractDataHandler): dist = haversine_dist(lat1=self._data.XLAT, lon1=self._data.XLONG, lat2=lat, lon2=lon) return dist - def get_nearest_coordinates(self, lat, lon, dim=None): + def compute_nearest_icoordinates(self, lat, lon, dim=None): dist = self.get_distances(lat=lat, lon=lon) if dim is None: @@ -84,78 +104,89 @@ class WrfChemDataHandler(AbstractDataHandler): return dist.argmin(dim) -def haversine_dist_single_point(lat1, lon1, lat2, lon2, to_radians=True, earth_radius=6371.): - """ - Calculate the great circle distance between two points - on the Earth (specified in decimal degrees or in radians) +class DataHandlerSingleGridCoulumn(WrfChemDataHandler): + DEFAULT_MODEL = "WRF-Chem" + DEFAULT_VAR_ALL_DICT = {'o3': 'dma8eu', 'relhum': 'average_values', 'temp': 'maximum', 'u': 'average_values', + 'v': 'average_values', 'no': 'dma8eu', 'no2': 'dma8eu', 'cloudcover': 'average_values', + 'pblheight': 'maximum'} + DEFAULT_WINDOW_LEAD_TIME = 3 + DEFAULT_WINDOW_HISTORY_SIZE = 13 + DEFAULT_WINDOW_HISTORY_OFFSET = 0 + DEFAULT_TARGET_VAR = "o3" + DEFAULT_TARGET_DIM = "variables" + DEFAULT_ITER_DIM = "points" + DEFAULT_WINDOW_DIM = "window" + + + def __init__(self, coords: Tuple[float, float], + network=DEFAULT_MODEL, target_dim=DEFAULT_TARGET_DIM, target_var=DEFAULT_TARGET_VAR, + iter_dim=DEFAULT_ITER_DIM, window_dim=DEFAULT_WINDOW_DIM, + window_history_size=DEFAULT_WINDOW_HISTORY_SIZE, window_history_offset=DEFAULT_WINDOW_HISTORY_OFFSET, + window_lead_time=DEFAULT_WINDOW_LEAD_TIME, + transformation=None, store_data_locally: bool = True, + min_length: int = 0, start=None, end=None, variables=None, **kwargs): + super().__init__(**kwargs) + self._set_coords(coords) + self.target_dim = target_dim + self.target_var = target_var + self.window_history_size = window_history_size + self.window_lead_time = window_lead_time + + self.open_data() + self.rechunk_data(self.rechunk_values) + self._set_nearest_icoords(dim=[self.logical_x_coord_name, self.logical_y_coord_name]) + self._set_nearest_coords() + + def _set_coords(self, coords): + __set_coords = dict(lat=None, lon=None) + if len(coords) != 2: + raise SyntaxError(f"`coords' must have length=2 (lat, lon), but has length={len(coords)}") + if isinstance(coords, tuple): + self.__coords = dict(lat=coords[0], lon=coords[1]) + print(f"self.__coords={self.__coords}") + elif isinstance(coords, dict): + if (coords.keys() == __set_coords.keys()): + self.__coords = coords + else: + raise KeyError(f"dict `coords' must have the keys `lat', and `lon' but has: {coords.keys()}") + else: + raise TypeError(f"`coords' must be a tuple of floats or a dict, but is of type: {type(coords)}") + + def get_coordinates(self, as_arrays=False) -> Union[Tuple[np.ndarray, np.ndarray], dict]: + if as_arrays: + return np.array(self.__coords['lat']), np.array(self.__coords['lon']) + else: + return self.__coords + + def _set_nearest_icoords(self, dim=None): + lat, lon = self.get_coordinates(as_arrays=True) + self.__nearest_icoords = dask.compute(self.compute_nearest_icoordinates(lat, lon, dim))[0] + + def get_nearest_icoords(self, as_arrays=False): + if as_arrays: + return (self.__nearest_icoords[self.logical_y_coord_name].values, + self.__nearest_icoords[self.logical_x_coord_name].values, + ) + else: + return {k: list(v.values) for k, v in self.__nearest_icoords.items()} + + def _set_nearest_coords(self): + icoords = self.get_nearest_icoords() + lat = self._data[self.physical_y_coord_name].sel(icoords) + lon = self._data[self.physical_x_coord_name].sel(icoords) + self.__nearest_coords = dask.compute(dict(lat=lat, lon=lon))[0] + + def get_nearest_coords(self, as_arrays=False): + if as_arrays: + return (self.__nearest_coords['lat'].values, + self.__nearest_coords['lon'].values, + ) + else: + return {k: v.values for k, v in self.__nearest_coords.items()} + - All (lat, lon) coordinates must have numeric dtypes and be of equal length. - An exception holds if one pair is given as a scalar, in this case broadcasting is possible - :param lat1: Latitude(-s) of first location - :param lon1: Longitude(-s) of first location - :param lat2: Latitude(-s) of second location - :param lon2: Longitude(-s) of second location - :param to_radians: Flag if conversion from degree to radiant is required - :param earth_radius: Earth radius in kilometers - :return: Distance between locations in kilometers - """ - if to_radians: - pass - a = da.sin((lat2 - lat1) / 2.0) ** 2 + \ - da.cos(lat1) * da.cos(lat2) * da.sin((lon2 - lon1) / 2.0) ** 2 - - return earth_radius * 2. * da.arcsin(da.sqrt(a)) - - -def greatcircle_dist(lat1, lon1, lat2, lon2, to_radians=True, earth_radius=6371.): - """ - - :param lat1: - :type lat1: - :param lon1: - :type lon1: - :param lat2: - :type lat2: - :param lon2: - :type lon2: - :param to_radians: - :type to_radians: - :param earth_radius: - :type earth_radius: - :return: - :rtype: - """ - if to_radians: - lat1, lon1, lat2, lon2 = np.deg2rad(lat1), np.deg2rad(lon1), np.deg2rad(lat2), np.deg2rad(lon2) - - del_sig = np.arccos( - np.sin(lat1) * np.sin(lat2) + np.cos(lat1) * np.cos(lat2) * np.cos(np.abs(lon2 - lon1)) - ) - dist = earth_radius * del_sig - return dist - - -def kdtree_fast(latvar, lonvar, lat0, lon0): - from scipy.spatial import cKDTree - rad_factor = np.pi / 180.0 # for trignometry, need angles in radians - # Read latitude and longitude from file into numpy arrays - latvals = np.deg2rad(latvar) - lonvals = np.deg2rad(lonvar) - ny, nx = latvals.shape - clat, clon = np.cos(latvals), np.cos(lonvals) - slat, slon = np.sin(latvals), np.sin(lonvals) - # Build kd-tree from big arrays of 3D coordinates - triples = list(zip(np.ravel(clat * clon), np.ravel(clat * slon), np.ravel(slat))) - kdt = cKDTree(triples) - lat0_rad = np.deg2rad(lat0) - lon0_rad = np.deg2rad(lon0) - clat0, clon0 = np.cos(lat0_rad), np.cos(lon0_rad) - slat0, slon0 = np.sin(lat0_rad), np.sin(lon0_rad) - dist_sq_min, minindex_1d = kdt.query([clat0 * clon0, clat0 * slon0, slat0]) - iy_min, ix_min = np.unravel_index(minindex_1d, latvals.shape) - return iy_min, ix_min if __name__ == '__main__': @@ -183,59 +214,73 @@ if __name__ == '__main__': plt.close('all') - wrf_dh = WrfChemDataHandler(data_path='/home/felix/Data/WRF-Chem/', - common_file_starter='wrfout_d01_2010-', - window_history_size=12, - window_lead_time=4, - time_dim='Time', - ) - wrf_dh.open_data() - wrf_dh.rechunk_data({"XTIME": 1, "y": 36, "x": 40}) - T2 = wrf_dh._data.T2 + lat_np = np.array([50.73333, 45.0]) + lon_np = np.array([7.1, 0.0]) - lat_np = np.array([50.73333]) - lon_np = np.array([7.1]) - icoords = dask.compute(wrf_dh.get_nearest_coordinates(lat_np, lon_np))[0] lat_xr = xr.DataArray(lat_np, dims=["points"], coords={'points': range(len(lat_np))}) lon_xr = xr.DataArray(lon_np, dims=["points"], coords={'points': range(len(lon_np))}) - dist_np = wrf_dh.get_distances(lat_np, lon_np) - dist_xr = wrf_dh.get_distances(lat_xr, lon_xr) - dist_xr.attrs.update(dict(units='km')) - dist_xr_set = xr.Dataset({'dist': dist_xr}) - - for i, (data, xlim, ylim) in enumerate(((wrf_dh._data.T2, [-42, 66], [23, 80]), - (dist_xr_set.dist, [-42, 66], [23, 80]), - (wrf_dh._data.T2.where(dist_xr.sel({'points': 0}).drop('points') <= 100), [2, 15], [45, 58]), - (dist_xr_set.dist.where(dist_xr.sel({'points': 0}).drop('points') <= 100), [2, 15], [45, 58]), - )): - plot_map_proj(data, xlim=xlim, - ylim=ylim, - point=[lat_np, lon_np], filename=f'test_dist{i}.pdf') - - - plot_map_proj(wrf_dh._data.T2.where(dist_xr.sel({'points': 0}).drop('points') <= 100), xlim=[-0, 15], ylim=[40, 58], - point=[lat_np, lon_np], filename='test_dist.pdf') - wrf_dh_18 = WrfChemDataHandler(data_path='/home/felix/Data/WRF-Chem/test_data_aura/', - common_file_starter='wrfout_d01_2018-08-01', - window_history_size=12, - window_lead_time=4, - time_dim='Time', - ) - wrf_dh_18.open_data() - wrf_dh_18.rechunk_data({"Time": 1, "south_north": 36, "west_east": 40}) - T2_18 = wrf_dh_18._data.T2 - T2_18.isel(Time=0).plot() - plt.savefig('test_fig2.pdf') - plt.close('all') - - # plot_map_proj(T2_18.isel(Time=0), [3, 18], [45, 57], filename='test_fig2.pdf') - # - # dist1_18_hav = haversine_dist_single_point(T2_18.XLAT, T2_18.XLONG, 50.733, 7.10) - # dist2_18_hav = haversine_dist_single_point(50.733, 7.10, T2_18.XLAT, T2_18.XLONG) - - icoords_18 = dask.compute(wrf_dh_18.get_nearest_coordinates(lat=50.733, lon=7.1, dim=['south_north', 'west_east']))[ - 0] - T2val = wrf_dh_18._data.isel(icoords_18).T2.values + use_first_dummy_dataset = True + if use_first_dummy_dataset: + wrf_dh = WrfChemDataHandler(data_path='/home/felix/Data/WRF-Chem/', + common_file_starter='wrfout_d01_2010-', + time_dim_name='Time', + ) + wrf_gridcol = DataHandlerSingleGridCoulumn((lat_xr, lon_xr), + data_path='/home/felix/Data/WRF-Chem/', + common_file_starter='wrfout_d01_2010-', + time_dim_name='Time', + + ) + wrf_dh.open_data() + wrf_dh.rechunk_data({"XTIME": 1, "y": 36, "x": 40}) + T2 = wrf_dh._data.T2 + + icoords = dask.compute(wrf_dh.compute_nearest_icoordinates(lat_np, lon_np))[0] + + dist_np = wrf_dh.get_distances(lat_np, lon_np) + dist_xr = wrf_dh.get_distances(lat_xr, lon_xr) + dist_xr.attrs.update(dict(units='km')) + dist_xr_set = xr.Dataset({'dist': dist_xr}) + + for i, (data, xlim, ylim) in enumerate(((wrf_dh._data.T2, [-42, 66], [23, 80]), + (dist_xr_set.dist, [-42, 66], [23, 80]), + (wrf_dh._data.T2.where(dist_xr.sel({'points': 0}).drop('points') <= 100), [2, 15], [45, 58]), + (dist_xr_set.dist.where(dist_xr.sel({'points': 0}).drop('points') <= 100), [2, 15], [45, 58]), + )): + plot_map_proj(data, xlim=xlim, + ylim=ylim, + point=[lat_np, lon_np], filename=f'test_dist{i}.pdf') + + ######################### # Larger 4D data + use_second_dummy_dataset = True + if use_second_dummy_dataset: + wrf_dh_4d = WrfChemDataHandler(data_path='/home/felix/Data/WRF-Chem/upload_aura/2009/2009', + common_file_starter='wrfout_d01_2009', + time_dim_name='Time', + ) + wrf_dh_4d.open_data() + wrf_dh_4d.rechunk_data({"Time": 1, "bottom_top": 34, "south_north": 36, "west_east": 40}) + lat_np = np.array([50.73333]) + lon_np = np.array([7.1]) + wrf_dh_4d._data = wrf_dh_4d._data.assign_coords(wrf_dh._data.coords) + icoords = dask.compute(wrf_dh_4d.compute_nearest_icoordinates(lat_np, lon_np))[0] + + dist_xr = wrf_dh_4d.get_distances(lat_xr, lon_xr) + dist_xr.attrs.update(dict(units='km')) + dist_xr_set = xr.Dataset({'dist': dist_xr}) + + for i, (data, xlim, ylim) in enumerate(((wrf_dh_4d._data.T2, [-42, 66], [23, 80]), + (dist_xr_set.dist, [-42, 66], [23, 80]), + (wrf_dh_4d._data.T2.where( + dist_xr.sel({'points': 0}).drop('points') <= 100), [2, 15], + [45, 58]), + (dist_xr_set.dist.where( + dist_xr.sel({'points': 0}).drop('points') <= 100), [2, 15], + [45, 58]), + )): + plot_map_proj(data, xlim=xlim, + ylim=ylim, + point=[lat_np, lon_np], filename=f'test_dist{i}.pdf') print()