diff --git a/src/data_handling/data_preparation.py b/src/data_handling/data_preparation.py
index 54ad03d05103323c3b68fa78218c011aaa9fe426..447500203c7f92c8db5f4ece6edc195587565b6b 100644
--- a/src/data_handling/data_preparation.py
+++ b/src/data_handling/data_preparation.py
@@ -27,7 +27,7 @@ num_or_list = Union[number, List[number]]
 data_or_none = Union[xr.DataArray, None]
 
 
-class AbstractStationPrep():
+class AbstractStationPrep(object):
     def __init__(self): #, path, station, statistics_per_var, transformation, **kwargs):
         pass
 
@@ -77,17 +77,28 @@ class StationPrep(AbstractStationPrep):
         self.kwargs = kwargs
         # self.kwargs["overwrite_local_data"] = overwrite_local_data
 
-        self.make_samples()
+        # self.make_samples()
+        self.setup_samples()
 
     def __str__(self):
         return self.station[0]
 
+    def __len__(self):
+        assert len(self.get_X()) == len(self.get_Y())
+        return len(self.get_X())
+
+    @property
+    def shape(self):
+        return self.data.shape, self.get_X().shape, self.get_Y().shape
+
     def __repr__(self):
-        return f"StationPrep(path='{self.path}', station={self.station}, statistics_per_var={self.statistics_per_var}, " \
-               f"transformation={self.transformation}, station_type='{self.station_type}', network='{self.network}', " \
+        return f"StationPrep(station={self.station}, data_path='{self.path}', " \
+               f"statistics_per_var={self.statistics_per_var}, " \
+               f"station_type='{self.station_type}', network='{self.network}', " \
                f"sampling='{self.sampling}', target_dim='{self.target_dim}', target_var='{self.target_var}', " \
                f"interpolate_dim='{self.interpolate_dim}', window_history_size={self.window_history_size}, " \
-               f"window_lead_time={self.window_lead_time}, **{self.kwargs})"
+               f"window_lead_time={self.window_lead_time}, overwrite_local_data={self.overwrite_local_data}, " \
+               f"transformation={self.transformation}, **{self.kwargs})"
 
     def get_transposed_history(self) -> xr.DataArray:
         """Return history.
@@ -116,10 +127,20 @@ class StationPrep(AbstractStationPrep):
             inverse=inverse
         )
 
-    def make_samples(self):
+    def set_transformation(self, transformation: dict):
+        if self._transform_method is not None:
+            self.call_transform(inverse=True)
+        self.transformation = self.setup_transformation(transformation)
+        self.call_transform()
+        self.make_samples()
+
+    def setup_samples(self):
         self.load_data()
         if self.transformation is not None:
             self.call_transform()
+        self.make_samples()
+
+    def make_samples(self):
         self.make_history_window(self.target_dim, self.window_history_size, self.interpolate_dim)
         self.make_labels(self.target_dim, self.target_var, self.interpolate_dim, self.window_lead_time)
         self.make_observation(self.target_dim, self.target_var, self.interpolate_dim)
@@ -433,20 +454,12 @@ class StationPrep(AbstractStationPrep):
         data.loc[..., used_chem_vars] = data.loc[..., used_chem_vars].clip(min=minimum)
         return data
 
-    def setup_transformation(self, transformation: Dict):
+    @staticmethod
+    def setup_transformation(transformation: Dict):
         """
         Set up transformation by extracting all relevant information.
 
-        Extract all information from transformation dictionary. Possible keys are scope. method, mean, and std. Scope
-        can either be station or data. Station scope means, that data transformation is performed for each station
-        independently (somehow like batch normalisation), whereas data scope means a transformation applied on the
-        entire data set.
-
-        * If using data scope, mean and standard deviation (each only if required by transformation method) can either
-        be calculated accurate or as an estimate (faster implementation). This must be set in dictionary either
-        as "mean": "accurate" or "mean": "estimate". In both cases, the required statistics are calculated and saved.
-        After this calculations, the mean key is overwritten by the actual values to use.
-        * If using station scope, no additional information is required.
+        Extract all information from transformation dictionary. Possible keys are method, mean, std, min, max.
         * If a transformation should be applied on base of existing values, these need to be provided in the respective
         keys "mean" and "std" (again only if required for given method).
 
@@ -460,26 +473,12 @@ class StationPrep(AbstractStationPrep):
             raise TypeError(f"`transformation' must be either `None' or dict like e.g. `{{'method': 'standardise'}},"
                             f" but transformation is of type {type(transformation)}.")
         transformation = transformation.copy()
-        scope = transformation.get("scope", "station")
-        # method = transformation.get("method", "standardise")
         method = transformation.get("method", None)
         mean = transformation.get("mean", None)
         std = transformation.get("std", None)
         max_val = transformation.get("max", None)
         min_val = transformation.get("min", None)
-        # if scope == "data":
-        #     if isinstance(mean, str):
-        #         if mean == "accurate":
-        #             mean, std = self.calculate_accurate_transformation(method)
-        #         elif mean == "estimate":
-        #             mean, std = self.calculate_estimated_transformation(method)
-        #         else:
-        #             raise ValueError(f"given mean attribute must either be equal to strings 'accurate' or 'estimate' or"
-        #                              f"be an array with already calculated means. Given was: {mean}")
-        # if scope == "station":
-        #     mean, std = None, None
-        # else:
-        #     raise ValueError(f"Scope argument can either be 'station' or 'data'. Given was: {scope}")
+
         transformation["method"] = method
         transformation["mean"] = mean
         transformation["std"] = std
@@ -1142,6 +1141,9 @@ if __name__ == "__main__":
                      network='UBA', sampling='daily', target_dim='variables', target_var='o3',
                      interpolate_dim='datetime', window_history_size=7, window_lead_time=3,
                      transformation={'method': 'standardise'})
+    sp.set_transformation({'method': 'standardise', 'mean': sp.mean+2, 'std': sp.std+1})
     sp.get_X()
     sp.get_Y()
+    print(len(sp))
+    print(sp.shape)
     print(sp)