diff --git a/src/data_handler/station_preparation.py b/src/data_handler/station_preparation.py index da8c3ad83bc3c794e540863f6343b7337484ee7d..42d94e277415c637d4fc9a5262692a6b3150b0a7 100644 --- a/src/data_handler/station_preparation.py +++ b/src/data_handler/station_preparation.py @@ -95,13 +95,29 @@ class StationPrep(AbstractStationPrep): return self.data.shape, self.get_X().shape, self.get_Y().shape def __repr__(self): - return f"StationPrep(station={self.station}, data_path='{self.path}, " \ + return f"StationPrep(station={self.station}, data_path='{self.path}', " \ f"statistics_per_var={self.statistics_per_var}, " \ f"station_type='{self.station_type}', network='{self.network}', " \ f"sampling='{self.sampling}', target_dim='{self.target_dim}', target_var='{self.target_var}', " \ f"interpolate_dim='{self.interpolate_dim}', window_history_size={self.window_history_size}, " \ f"window_lead_time={self.window_lead_time}, overwrite_local_data={self.overwrite_local_data}, " \ - f"transformation={self.transformation}, **{self.kwargs})" + f"transformation={self._print_transformation_as_string}, **{self.kwargs})" + + @property + def _print_transformation_as_string(self): + str_name = '' + if self.transformation is None: + str_name = f'{None}' + else: + for k, v in self.transformation.items(): + if v is not None: + try: + v_pr = f"xr.DataArray.from_dict({v.to_dict()})" + except AttributeError: + v_pr = f"'{v}'" + str_name += f"'{k}':{v_pr}, " + str_name = f"{{{str_name}}}" + return str_name def get_transposed_history(self) -> xr.DataArray: """Return history. @@ -608,6 +624,8 @@ class StationPrep(AbstractStationPrep): self.check_inverse_transform_params(self.mean, self.std, self._transform_method) self.data, self.mean, self.std = f_inverse(self.data, self.mean, self.std, self._transform_method) self._transform_method = None + # update X and Y + self.make_samples() def get_transformation_information(self, variable: str = None) -> Tuple[data_or_none, data_or_none, str]: """ @@ -641,8 +659,14 @@ if __name__ == "__main__": statistics_per_var=statistics_per_var, station_type='background', network='UBA', sampling='daily', target_dim='variables', target_var='o3', interpolate_dim='datetime', window_history_size=7, window_lead_time=3, - transformation={'method': 'standardise'}) - sp.set_transformation({'method': 'standardise', 'mean': sp.mean+2, 'std': sp.std+1}) + ) # transformation={'method': 'standardise'}) + # sp.set_transformation({'method': 'standardise', 'mean': sp.mean+2, 'std': sp.std+1}) + sp2 = StationPrep(data_path='/home/felix/PycharmProjects/mlt_new/data/', station='DEBY122', + statistics_per_var=statistics_per_var, station_type='background', + network='UBA', sampling='daily', target_dim='variables', target_var='o3', + interpolate_dim='datetime', window_history_size=7, window_lead_time=3, + transformation={'method': 'standardise'}) + sp2.transform(inverse=True) sp.get_X() sp.get_Y() print(len(sp))