diff --git a/mlair/data_handler/data_handler_wrf_chem.py b/mlair/data_handler/data_handler_wrf_chem.py index 227a609b33c9ecd10c88cd89147e366fe8c34539..409a31e476448413f562df319a4c840df39d9fb0 100644 --- a/mlair/data_handler/data_handler_wrf_chem.py +++ b/mlair/data_handler/data_handler_wrf_chem.py @@ -9,8 +9,11 @@ import matplotlib.pyplot as plt import dask import inspect import os +import gc from mlair.helpers.geofunctions import haversine_dist, bearing_angle, WindSector, VectorRotateLambertConformal2latlon from mlair.helpers.helpers import convert2xrda, remove_items, to_list +from mlair import helpers + from mlair.data_handler.data_handler_single_station import DataHandlerSingleStation from mlair.data_handler import DefaultDataHandler @@ -63,7 +66,7 @@ class BaseWrfChemDataLoader: staged_vars: List[str] = DEFAULT_STAGED_VARS, staged_rotation_opts: Dict = DEFAULT_STAGED_ROTATION_opts, vars_to_rotate: Tuple[Tuple[Tuple[str, str], Tuple[str, str]]] = DEFAULT_VARS_TO_ROTATE, - staged_dimension_mapping = None, stag_ending = '_stag', + staged_dimension_mapping=None, stag_ending='_stag', ): # super().__init__() @@ -265,6 +268,21 @@ class SingleGridColumnWrfChemDataLoader(BaseWrfChemDataLoader): self._nearest_coords = None self.external_coords_file = external_coords_file + # self.open_data() + # + # if self.physical_t_coord_name != self.time_dim_name: + # self.assign_coords( + # {self.physical_t_coord_name: (self.time_dim_name, self._data[self.physical_t_coord_name].values)}) + # + # self._set_dims_as_coords() + # if external_coords_file is not None: + # self._apply_external_coordinates() + # self.apply_staged_transormation() + # self.rechunk_data(self.rechunk_values) + # self._set_geoinfos() + logging.debug("SingleGridColumnWrfChemDataLoader Initialised") + + def __enter__(self): self.open_data() if self.physical_t_coord_name != self.time_dim_name: @@ -272,12 +290,16 @@ class SingleGridColumnWrfChemDataLoader(BaseWrfChemDataLoader): {self.physical_t_coord_name: (self.time_dim_name, self._data[self.physical_t_coord_name].values)}) self._set_dims_as_coords() - if external_coords_file is not None: + if self.external_coords_file is not None: self._apply_external_coordinates() self.apply_staged_transormation() self.rechunk_data(self.rechunk_values) self._set_geoinfos() - logging.debug("SingleGridColumnWrfChemDataLoader Initialised") + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.data.close() + gc.collect() def _set_geoinfos(self): # identify nearest coords @@ -388,6 +410,7 @@ class DataHandlerSingleGridColumn(DataHandlerSingleStation): self._logical_z_coord_name = None self._joint_z_coord_selector = self._extract_largest_coord_extractor(self.var_logical_z_coord_selector, self.targetvar_logical_z_coord_selector) + self.__loader = None super().__init__(*args, **kwargs) @staticmethod @@ -418,38 +441,51 @@ class DataHandlerSingleGridColumn(DataHandlerSingleStation): def setup_data_path(self, data_path: str, sampling: str): return data_path - def extract_data_from_loader(self, loader, station): + def extract_data_from_loader(self, loader): data = loader.data.isel(loader.get_nearest_icoords()).squeeze()[self.variables] return data + @property + def loader(self): + return self.__loader + + @loader.setter + def loader(self, station_path): + try: + station, path = station_path + except ValueError as e: + raise ValueError(f"Pass an iterable with two items; (station, path)") + lat, lon = self.coord_str2coords(station) + loader = SingleGridColumnWrfChemDataLoader((lat, lon), + data_path=path, + external_coords_file=self.external_coords_file, + time_dim_name=self.time_dim + ) + self.__loader = loader + def load_data(self, path, station, statistics_per_var, sampling, station_type=None, network=None, store_data_locally=False, data_origin: Dict = None, start=None, end=None): - lat, lon = self.coord_str2coords(station) - sgc_loader = SingleGridColumnWrfChemDataLoader((lat, lon), - data_path=path, - external_coords_file=self.external_coords_file, - time_dim_name=self.time_dim - ) - if self._logical_z_coord_name is None: - self._logical_z_coord_name = sgc_loader.logical_z_coord_name - # # select defined variables at grid box or grid coloumn based on nearest icoords - data = self.extract_data_from_loader(sgc_loader, station) - if self._joint_z_coord_selector is not None: - data = data.sel({self._logical_z_coord_name: self._joint_z_coord_selector}) - # expand dimesion for iterdim - data = data.expand_dims({self.iter_dim: station}).to_array(self.target_dim) - # transpose dataarray: set first three fixed and keep remaining as is - data = data.transpose(self.iter_dim, self.time_dim, self.target_dim, ...) - - data = dask.compute(self._slice_prep(data, start=start, end=end))[0] - sgc_loader.data.close() - # data = self.check_for_negative_concentrations(data) - - # ToDo - # data should somehow look like this: - # < xarray.DataArray(Stations: 1, datetime: 7670, variables: 9) (From DataHandlerSingleStation) - meta = None + self.loader = (station, path) + + with self.loader as loader: + if self._logical_z_coord_name is None: + self._logical_z_coord_name = loader.logical_z_coord_name + # # select defined variables at grid box or grid coloumn based on nearest icoords + data = self.extract_data_from_loader(loader) + if self._joint_z_coord_selector is not None: + data = data.sel({self._logical_z_coord_name: self._joint_z_coord_selector}) + # expand dimesion for iterdim + data = data.expand_dims({self.iter_dim: station}).to_array(self.target_dim) + # transpose dataarray: set first three fixed and keep remaining as is + data = data.transpose(self.iter_dim, self.time_dim, self.target_dim, ...) + + data = dask.compute(self._slice_prep(data, start=start, end=end))[0] + # ToDo + # data should somehow look like this: + # < xarray.DataArray(Stations: 1, datetime: 7670, variables: 9) (From DataHandlerSingleStation) + meta = None + return data, meta @staticmethod @@ -496,19 +532,46 @@ class DataHandlerSingleGridColumn(DataHandlerSingleStation): self.make_history_window(self.target_dim, self.window_history_size, self.time_dim) if self.var_logical_z_coord_selector is not None: self.history = self.history.sel({self._logical_z_coord_name: self.var_logical_z_coord_selector}) + self.history = self.modify_history() self.make_labels(self.target_dim, self.target_var, self.time_dim, self.window_lead_time) self.make_observation(self.target_dim, self.target_var, self.time_dim) if self.targetvar_logical_z_coord_selector is not None: self.label = self.label.sel({self._logical_z_coord_name: self.targetvar_logical_z_coord_selector}) self.observation = self.observation.sel( {self._logical_z_coord_name: self.targetvar_logical_z_coord_selector}) + self.label = self.modify_label() + self.observation = self.modify_observation() self.remove_nan(self.time_dim) # self.input_data = self.input_data.compute() # self.label = self.label.compute() # self.observation = self.observation.compute() # self.target_data = self.target_data.compute() - self._data.close() + # self._data.close() + + def modify_history(self): + """ + Place holder for more user spec. processing of history. + :return: + :rtype: + """ + return self.history + + def modify_label(self): + """ + Place holder for more user spec. processing of label. + :return: + :rtype: + """ + return self.label + + def modify_observation(self): + """ + Place holder for more user spec. processing of observation. + :return: + :rtype: + """ + return self.observation # @TimeTrackingWrapper # def setup_samples(self): @@ -534,6 +597,78 @@ class DataHandlerWRF(DefaultDataHandler): _requirements = data_handler.requirements() +class DataHandlerSectorGrid(DataHandlerSingleGridColumn): + # _requirements1 = remove_items(inspect.getfullargspec(DataHandlerWRF).args, ["self", "station"]) + # _requirements2 = remove_items(inspect.getfullargspec(DataHandlerSingleGridColumn).args, ["self", "station"]) + _requirements = DataHandlerWRF.requirements() + + def __init__(self, *args, radius=None, sectors=None, sector_dim_name=None, **kwargs): + if radius is None: + radius = 100 # km + self.radius = radius + if sectors is None: + sectors = WindSector.DEFAULT_WIND_SECTORS + self.sectors = sectors + if sector_dim_name is None: + sector_dim_name = 'wind_sector' + self.sector_dim_name = sector_dim_name + self.windsector = WindSector(wind_sectors=self.sectors) + self._added_vars = [] + super().__init__(*args, **kwargs) + + def extract_data_from_loader(self, loader): + wind_dir_name = self._get_wind_dir_var_name(loader) + full_data = loader.data.isel(loader.get_nearest_icoords()).squeeze() + data = full_data[self.variables] + sec_data = self.windsector.get_sect_of_value(full_data[wind_dir_name]) + if wind_dir_name not in data: + data[wind_dir_name] = full_data[wind_dir_name] + self._added_vars.append(to_list(wind_dir_name)) + # data[self.sector_dim_name] = sec_data + return data + + def _get_wind_dir_var_name(self, loader, wdir_name3d='wdirll', wdir_name2d='wdir10ll'): + """ + Get variable name of wind direction. If no wind variable is given in self.variables, 10m wind is used as default. + If a wind variable is used in self.variables, returns the variable name (e.g. wdirll if Ull or Vll is in self.variables, + or wdir10ll if U10ll or V10ll is in self.variables. The same holds if wdirll/wspdll or wdir10ll/wspd10ll is given. + + :return: + :rtype: + """ + + if (wdir_name3d in self.variables) and not (wdir_name2d in self.variables): + var_name = wdir_name3d + elif ((wdir_name2d in self.variables) and not (wdir_name3d in self.variables)) or (not (wdir_name2d in self.variables) and not (wdir_name3d in self.variables)): + var_name = wdir_name2d + else: + raise ValueError(f"invalid wind var name passed to `_get_wind_dir_var_name'") + + assert var_name in itertools.chain(*itertools.chain(*loader.vars_to_rotate)) + return var_name + + def modify_history(self): + ws_data = self.windsector.wind_sector_edges_data() + # t, topts = self.transform(ws_data, opts=self._transformation[0]['wdir10ll'], + # transformation_dim='index') + t = self.transform(self.windsector.wind_sector_edges_data()) + return self.history + + # def set_inputs_and_targets(self): + # inputs = self._data.sel({self.target_dim: helpers.to_list(self.variables) + helpers.to_list(self.sector_dim_name)}) + # targets = self._data.sel( + # {self.target_dim: helpers.to_list(self.target_var)}) # ToDo: is it right to expand this dim?? + # self.input_data = inputs + # self.target_data = targets + + +class DataHandlerMainSectWRF(DefaultDataHandler): + """Data handler using DataHandlerSectorGrid.""" + data_handler = DataHandlerSectorGrid + data_handler_transformation = DataHandlerSectorGrid + _requirements = data_handler.requirements() + + if __name__ == '__main__': def plot_map_proj(data, xlim=None, ylim=None, filename=None, point=None, radius=None, **kwargs):