diff --git a/src/data_handling/data_generator.py b/src/data_handling/data_generator.py index db3b044c4909ce7acf28d5a2b633e835fbc58915..85c25030a5f15ba2372b2a242acca3496dfebc4a 100644 --- a/src/data_handling/data_generator.py +++ b/src/data_handling/data_generator.py @@ -4,6 +4,7 @@ __date__ = '2019-11-07' import os from typing import Union, List, Tuple, Any, Dict +import dask.array as da import keras import xarray as xr import pickle @@ -114,8 +115,24 @@ class DataGenerator(keras.utils.Sequence): return transformation def calculate_accurate_transformation(self, method): + tmp = [] mean = None std = None + for station in self.stations: + try: + data = DataPrep(self.data_path, self.network, station, self.variables, station_type=self.station_type, + **self.kwargs) + chunks = (1, 100, data.data.shape[2]) + tmp.append(da.from_array(data.data.data, chunks=chunks)) + except EmptyQueryResult: + continue + tmp = da.concatenate(tmp, axis=1) + if method in ["standardise", "centre"]: + mean = da.nanmean(tmp, axis=1).compute() + mean = xr.DataArray(mean.flatten(), coords={"variables": sorted(self.variables)}, dims=["variables"]) + if method == "standardise": + std = da.nanstd(tmp, axis=1).compute() + std = xr.DataArray(std.flatten(), coords={"variables": sorted(self.variables)}, dims=["variables"]) return mean, std def calculate_estimated_transformation(self, method): @@ -131,7 +148,7 @@ class DataGenerator(keras.utils.Sequence): data.transform("datetime", method=method, inverse=True) except EmptyQueryResult: continue - return mean.mean("Stations") if mean.shape[1] > 0 else "hi", std.mean("Stations") if std.shape[1] > 0 else None + return mean.mean("Stations") if mean.shape[1] > 0 else None, std.mean("Stations") if std.shape[1] > 0 else None def get_data_generator(self, key: Union[str, int] = None, local_tmp_storage: bool = True) -> DataPrep: """