diff --git a/mlair/data_handler/input_bootstraps.py b/mlair/data_handler/input_bootstraps.py
index b8ad614f2317e804d415b23308df760f4dd8da7f..c33e90caf74389a686c3955c4072cc89d41349b6 100644
--- a/mlair/data_handler/input_bootstraps.py
+++ b/mlair/data_handler/input_bootstraps.py
@@ -75,6 +75,22 @@ class BootstrapIterator(Iterator):
         else:
             return self._method.apply(data)
 
+    def _prepare_data_for_next(self):
+        """
+        Load X and Y and add the bootstrap dimension for the current collection entry.
+
+        The entry is looked up before any data is loaded, so an IndexError caused by a position
+        beyond the collection reaches the calling ``__next__`` without loading data first.
+
+        :return: list of expanded inputs, expanded target, current entry of the collection
+        """
+        index_dimension_collection = self._collection[self._position]
+        nboot = self._data.number_of_bootstraps
+        _X, _Y = self._data.data.get_data(as_numpy=False)
+        _X = list(map(lambda x: x.expand_dims({self.boot_dim: range(nboot)}, axis=-1), _X))
+        _Y = _Y.expand_dims({self.boot_dim: range(nboot)}, axis=-1)
+        return _X, _Y, index_dimension_collection
+
 
 class BootstrapIteratorSingleInput(BootstrapIterator):
     _position: int = None
@@ -85,11 +101,7 @@ class BootstrapIteratorSingleInput(BootstrapIterator):
     def __next__(self):
         """Return next element or stop iteration."""
         try:
-            index, dimension = self._collection[self._position]
-            nboot = self._data.number_of_bootstraps
-            _X, _Y = self._data.data.get_data(as_numpy=False)
-            _X = list(map(lambda x: x.expand_dims({self.boot_dim: range(nboot)}, axis=-1), _X))
-            _Y = _Y.expand_dims({self.boot_dim: range(nboot)}, axis=-1)
+            _X, _Y, (index, dimension) = self._prepare_data_for_next()
             single_variable = _X[index].sel({self._dimension: [dimension]})
             bootstrapped_variable = self.apply_bootstrap_method(single_variable.values)
             bootstrapped_data = xr.DataArray(bootstrapped_variable, coords=single_variable.coords,
@@ -117,11 +129,7 @@ class BootstrapIteratorVariable(BootstrapIterator):
     def __next__(self):
         """Return next element or stop iteration."""
         try:
-            dimension = self._collection[self._position]
-            nboot = self._data.number_of_bootstraps
-            _X, _Y = self._data.data.get_data(as_numpy=False)
-            _X = list(map(lambda x: x.expand_dims({self.boot_dim: range(nboot)}, axis=-1), _X))
-            _Y = _Y.expand_dims({self.boot_dim: range(nboot)}, axis=-1)
+            _X, _Y, dimension = self._prepare_data_for_next()
             for index in range(len(_X)):
                 if dimension in _X[index].coords[self._dimension]:
                     single_variable = _X[index].sel({self._dimension: [dimension]})
@@ -150,11 +158,7 @@ class BootstrapIteratorBranch(BootstrapIterator):
 
     def __next__(self):
         try:
-            index = self._collection[self._position]
-            nboot = self._data.number_of_bootstraps
-            _X, _Y = self._data.data.get_data(as_numpy=False)
-            _X = list(map(lambda x: x.expand_dims({self.boot_dim: range(nboot)}, axis=-1), _X))
-            _Y = _Y.expand_dims({self.boot_dim: range(nboot)}, axis=-1)
+            _X, _Y, index = self._prepare_data_for_next()
             for dimension in _X[index].coords[self._dimension].values:
                 single_variable = _X[index].sel({self._dimension: [dimension]})
                 bootstrapped_variable = self.apply_bootstrap_method(single_variable.values)
@@ -172,6 +176,51 @@ class BootstrapIteratorBranch(BootstrapIterator):
         return list(range(len(data.get_X(as_numpy=False))))
 
 
+class BootstrapIteratorVariableSets(BootstrapIterator):
+    """
+    Bootstrap all variables of a variable set in one input branch at the same time.
+
+    A variable set contains all variables of a branch whose names end with the same entry of
+    ``_variable_set_splitters``. Each iteration step bootstraps exactly one set.
+    """
+
+    _variable_set_splitters: list = ['Sect', 'SectLeft', 'SectRight']
+
+    def __next__(self):
+        """Return next element or stop iteration."""
+        try:
+            _X, _Y, (index, dimensions) = self._prepare_data_for_next()
+            for dimension in dimensions:
+                single_variable = _X[index].sel({self._dimension: [dimension]})
+                bootstrapped_variable = self.apply_bootstrap_method(single_variable.values)
+                bootstrapped_data = xr.DataArray(bootstrapped_variable, coords=single_variable.coords,
+                                                 dims=single_variable.dims)
+                _X[index] = bootstrapped_data.combine_first(_X[index]).transpose(*_X[index].dims)
+            self._position += 1
+        except IndexError:  # position ran past the collection
+            raise StopIteration()
+        _X, _Y = self._to_numpy(_X), self._to_numpy(_Y)
+        return self._reshape(_X), self._reshape(_Y), (index, dimensions)
+
+    @classmethod
+    def create_collection(cls, data, dim):
+        """
+        Collect one (branch index, variable names) entry per non-empty variable set.
+
+        :param data: data handler whose ``get_X(as_numpy=False)`` returns the input branches
+        :param dim: name of the dimension holding the variable names
+        :return: list of tuples (branch index, list of variable names forming one set)
+        """
+        collection = []
+        for i, x in enumerate(data.get_X(as_numpy=False)):
+            variables = x.indexes[dim].to_list()
+            for splitter in cls._variable_set_splitters:
+                variable_set = [var for var in variables if var.endswith(splitter)]
+                if variable_set:  # an empty set would only reproduce the unmodified data
+                    collection.append((i, variable_set))
+        return collection
+
+
 class ShuffleBootstraps:
 
     @staticmethod
@@ -225,10 +274,12 @@ class Bootstraps(Iterable):
         self.bootstrap_method = {"shuffle": ShuffleBootstraps(),
                                  "zero_mean": MeanBootstraps(mean=0)}.get(
             bootstrap_method)  # todo adjust number of bootstraps if mean bootstrapping
+        self.bootstrap_type = bootstrap_type
         self.BootstrapIterator = {"singleinput": BootstrapIteratorSingleInput,
                                   "branch": BootstrapIteratorBranch,
-                                  "variable": BootstrapIteratorVariable}.get(bootstrap_type,
-                                                                             BootstrapIteratorSingleInput)
+                                  "variable": BootstrapIteratorVariable,
+                                  "group_of_variables": BootstrapIteratorVariableSets,
+                                  }.get(bootstrap_type, BootstrapIteratorSingleInput)
 
     def __iter__(self):
         return self.BootstrapIterator(self, self.bootstrap_method)
@@ -236,6 +287,11 @@ class Bootstraps(Iterable):
     def __len__(self):
         return len(self.BootstrapIterator.create_collection(self.data, self.bootstrap_dimension))
 
+    def __repr__(self):
+        return f"Bootstraps(data={self.data}, number_of_bootstraps={self.number_of_bootstraps}, " \
+               f"bootstrap_dimension='{self.bootstrap_dimension}', bootstrap_type='{self.bootstrap_type}', " \
+               f"bootstrap_method='{self.bootstrap_method}')"
+
     def bootstraps(self):
         return self.BootstrapIterator.create_collection(self.data, self.bootstrap_dimension)
 