diff --git a/mlair/data_handler/data_handler_wrf_chem.py b/mlair/data_handler/data_handler_wrf_chem.py index c1c2d0c5499ea617c7d3c01ff694e75378f855ff..81a1458a400fd6eb3e7e586c7f2c0f5ade039af2 100644 --- a/mlair/data_handler/data_handler_wrf_chem.py +++ b/mlair/data_handler/data_handler_wrf_chem.py @@ -25,6 +25,7 @@ from cartopy.geodesic import Geodesic import shapely float_np_xr = Union[float, np.ndarray, xr.DataArray, xr.Dataset] +int_or_list_of_int = Union[int, List[int]] class BaseWrfChemDataLoader: @@ -63,7 +64,7 @@ class BaseWrfChemDataLoader: staged_dimension_mapping = None, stag_ending = '_stag', ): - super().__init__() + # super().__init__() self.data_path = data_path self.common_file_starter = common_file_starter self.time_dim_name = time_dim_name @@ -76,7 +77,6 @@ class BaseWrfChemDataLoader: self.physical_y_coord_name = physical_y_coord_name self.physical_t_coord_name = physical_t_coord_name - self.staged_vars = to_list(staged_vars) self.staged_rotation_opts = staged_rotation_opts self.vars_to_rotate = vars_to_rotate @@ -120,7 +120,6 @@ class BaseWrfChemDataLoader: logging.debug(f'open data: {self.dataset_search_str}') data = xr.open_mfdataset(paths=self.dataset_search_str, combine='nested', concat_dim=self.time_dim_name, parallel=True) - # data = data.assign_coords({'XTIME': data.XTIME.values}) self._data = data def assign_coords(self, coords, **coords_kwargs): @@ -145,17 +144,11 @@ class BaseWrfChemDataLoader: try: with xr.open_mfdataset(paths=filenamestr, combine='nested', concat_dim='Time', parallel=True, chunks={'Time': 12}) as data: - # data = xr.open_mfdataset(paths=filenamestr, combine='nested', concat_dim='Time', - # parallel=True, chunks={'Time': 12}) status = 'OK' except: status = 'FAIL' print(f'{i}: {filenamestr} {status}') - @staticmethod - def plot_map(data_array): - pass - def get_distances(self, lat, lon): dist = haversine_dist(lat1=self._data.XLAT, lon1=self._data.XLONG, lat2=lat, lon2=lon) return dist @@ -242,8 +235,6 @@ class BaseWrfChemDataLoader: return new_field - - class SingleGridColumnWrfChemDataLoader(BaseWrfChemDataLoader): DEFAULT_MODEL = "WRF-Chem" DEFAULT_VAR_ALL_DICT = {'o3': 'dma8eu', 'relhum': 'average_values', 'temp': 'maximum', 'u': 'average_values', @@ -257,15 +248,11 @@ class SingleGridColumnWrfChemDataLoader(BaseWrfChemDataLoader): DEFAULT_ITER_DIM = "points" DEFAULT_WINDOW_DIM = "window" - def __init__(self, coords: Tuple[float_np_xr, float_np_xr], - network=DEFAULT_MODEL, target_dim=DEFAULT_TARGET_DIM, target_var=DEFAULT_TARGET_VAR, - iter_dim=DEFAULT_ITER_DIM, window_dim=DEFAULT_WINDOW_DIM, - window_history_size=DEFAULT_WINDOW_HISTORY_SIZE, window_history_offset=DEFAULT_WINDOW_HISTORY_OFFSET, + target_dim=DEFAULT_TARGET_DIM, target_var=DEFAULT_TARGET_VAR, + window_history_size=DEFAULT_WINDOW_HISTORY_SIZE, window_lead_time=DEFAULT_WINDOW_LEAD_TIME, - external_coords_file: str = None, - transformation=None, store_data_locally: bool = True, - min_length: int = 0, start=None, end=None, variables=None, **kwargs): + external_coords_file: str = None, **kwargs): super().__init__(**kwargs) self._set_coords(coords) self.target_dim = target_dim @@ -287,8 +274,6 @@ class SingleGridColumnWrfChemDataLoader(BaseWrfChemDataLoader): self._apply_external_coordinates() self.apply_staged_transormation() self.rechunk_data(self.rechunk_values) - # self._set_nearest_icoords(dim=[self.logical_x_coord_name, self.logical_y_coord_name]) - # self._set_nearest_coords() self._set_geoinfos() logging.debug("SingleGridColumnWrfChemDataLoader Initialised") @@ -333,7 +318,7 @@ class SingleGridColumnWrfChemDataLoader(BaseWrfChemDataLoader): data = data.assign_coords({k: (remove_items(list(v.dims), 'Time'), v.values.squeeze())}) self._data = data ds_coords.close() - logging.info('setup external coords') + logging.debug('setup external coords') def _set_coords(self, coords): __set_coords = dict(lat=None, lon=None) @@ -390,10 +375,32 @@ class SingleGridColumnWrfChemDataLoader(BaseWrfChemDataLoader): class DataHandlerSingleGridColumn(DataHandlerSingleStation): _requirements = remove_items(inspect.getfullargspec(DataHandlerSingleStation).args, ["self", "station"]) - def __init__(self, *args, external_coords_file=None, **kwargs): + def __init__(self, *args, external_coords_file=None, + var_logical_z_coord_selector=None, + targetvar_logical_z_coord_selector=None, **kwargs): self.external_coords_file = external_coords_file + self.var_logical_z_coord_selector = self._ret_z_coord_select_if_valid(var_logical_z_coord_selector, + as_input=True) + self.targetvar_logical_z_coord_selector = self._ret_z_coord_select_if_valid(targetvar_logical_z_coord_selector, + as_input=False) + self._logical_z_coord_name = None super().__init__(*args, **kwargs) + @staticmethod + def _ret_z_coord_select_if_valid(z_coord: int_or_list_of_int, as_input: bool) -> int_or_list_of_int: + if isinstance(z_coord, int): + return z_coord + elif isinstance(z_coord, list): + all_entries_int = all(isinstance(zi, int) for zi in z_coord) + if all_entries_int and as_input: + return z_coord + else: + raise NotImplementedError( + f"z_coord selector of type list not implemented for target vars. " + f"`all_entries_int'={all_entries_int}, `as_input'= {as_input}") + else: + raise TypeError(f"`z_coord' must be of type int or list[int], but is {type(z_coord)}") + @staticmethod def coord_str2coords(str_coords: str, sep='__', dec_marker='_') -> Tuple[float_np_xr, float_np_xr]: if isinstance(str_coords, list) and len(str_coords) == 1: @@ -407,6 +414,10 @@ class DataHandlerSingleGridColumn(DataHandlerSingleStation): def setup_data_path(self, data_path: str, sampling: str): return data_path + def extract_data_from_loader(self, loader, station): + data = loader.data.isel(loader.get_nearest_icoords()).squeeze()[self.variables] + return data + def load_data(self, path, station, statistics_per_var, sampling, station_type=None, network=None, store_data_locally=False, data_origin: Dict = None, start=None, end=None): @@ -416,11 +427,11 @@ class DataHandlerSingleGridColumn(DataHandlerSingleStation): external_coords_file=self.external_coords_file, time_dim_name=self.time_dim ) - - # select defined variables at grid box or grid coloumn based on nearest icoords - data = sgc_loader.data.isel(sgc_loader.get_nearest_icoords()).squeeze()[self.variables] + if self._logical_z_coord_name is None: + self._logical_z_coord_name = sgc_loader.logical_z_coord_name + # # select defined variables at grid box or grid coloumn based on nearest icoords + data = self.extract_data_from_loader(sgc_loader, station) # expand dimesion for iterdim - # data = data.expand_dims({self.iter_dim: station}).to_array(self.target_dim).compute() data = data.expand_dims({self.iter_dim: station}).to_array(self.target_dim) # transpose dataarray: set first three fixed and keep remaining as is data = data.transpose(self.iter_dim, self.time_dim, self.target_dim, ...) @@ -432,21 +443,32 @@ class DataHandlerSingleGridColumn(DataHandlerSingleStation): # data should somehow look like this: # < xarray.DataArray(Stations: 1, datetime: 7670, variables: 9) (From DataHandlerSingleStation) meta = None - # ToDo - # data, meta = None, None - # raise NotImplementedError - # return data.chunk({self.time_dim:-1}), meta return data, meta def get_X(self, upsampling=False, as_numpy=False): if as_numpy is True: - return None + # return None + raise NotImplementedError(f"keyword argument `as_numpy=True' not implemented.") elif as_numpy is False: return self.get_transposed_history() # def get_Y(self, upsampling=False, as_numpy=False): # raise NotImplementedError + def get_transposed_history(self) -> xr.DataArray: + """Return history. + + :return: history with dimensions datetime, window, Stations, variables. + """ + return self.history.transpose(self.time_dim, self.window_dim, self.iter_dim, self.target_dim, ...).copy() + + def get_transposed_label(self) -> xr.DataArray: + """Return label. + + :return: label with dimensions datetime*, window*, Stations, variables. + """ + return self.label.squeeze([self.iter_dim, self.target_dim]).transpose(self.time_dim, self.window_dim, ...).copy() + # def set_inputs_and_targets(self): # # inputs = self._data.sel({self.target_dim: helpers.to_list(self.variables)}) # # targets = self._data.sel( @@ -457,8 +479,15 @@ class DataHandlerSingleGridColumn(DataHandlerSingleStation): def make_samples(self): self.make_history_window(self.target_dim, self.window_history_size, self.time_dim) + if self.var_logical_z_coord_selector is not None: + self.history = self.history.sel({self._logical_z_coord_name: self.var_logical_z_coord_selector}) self.make_labels(self.target_dim, self.target_var, self.time_dim, self.window_lead_time) self.make_observation(self.target_dim, self.target_var, self.time_dim) + if self.targetvar_logical_z_coord_selector is not None: + self.label = self.label.sel({self._logical_z_coord_name: self.targetvar_logical_z_coord_selector}) + self.observation = self.observation.sel( + {self._logical_z_coord_name: self.targetvar_logical_z_coord_selector}) + self.remove_nan(self.time_dim) # self.input_data = self.input_data.compute() # self.label = self.label.compute() @@ -484,13 +513,12 @@ class DataHandlerSingleGridColumn(DataHandlerSingleStation): class DataHandlerWRF(DefaultDataHandler): - """Data handler using CDC.""" + """Data handler using DataHandlerSingleGridColumn.""" data_handler = DataHandlerSingleGridColumn data_handler_transformation = DataHandlerSingleGridColumn _requirements = data_handler.requirements() - if __name__ == '__main__': def plot_map_proj(data, xlim=None, ylim=None, filename=None, point=None, radius=None, **kwargs):