Skip to content
Snippets Groups Projects
Commit 690fca57 authored by v.gramlich1's avatar v.gramlich1
Browse files

Trying to make bins and rates_cap more flexible, inserting default values.

parent faf3c2c6
No related branches found
No related tags found
1 merge request!302Draft: Resolve "Class-based Oversampling technique"
Pipeline #70761 failed
......@@ -55,9 +55,8 @@ DEFAULT_DATA_ORIGIN = {"cloudcover": "REA", "humidity": "REA", "pblheight": "REA
"pm10": "", "so2": ""}
DEFAULT_USE_MULTIPROCESSING = True
DEFAULT_USE_MULTIPROCESSING_ON_DEBUG = False
DEFAULT_BINS = 10
DEFAULT_RATES_CAP = 20
DEFAULT_OVERSAMPLING_BINS = 10
DEFAULT_OVERSAMPLING_RATES_CAP = 20
def get_defaults():
......
......@@ -166,6 +166,38 @@ class DefaultDataHandler(AbstractDataHandler):
def apply_transformation(self, data, base="target", dim=0, inverse=False):
    """Delegate (inverse) transformation of ``data`` to the wrapped ``id_class`` handler."""
    kwargs = {"dim": dim, "base": base, "inverse": inverse}
    return self.id_class.apply_transformation(data, **kwargs)
def apply_oversampling(self, bin_edges, oversampling_rates, timedelta=(1, 'm')):
    """Duplicate samples whose labels fall into rare histogram bins.

    For each bin ``[bin_edges[i], bin_edges[i+1])`` the time steps whose
    labels occur in that bin are extracted and appended to
    ``self._X_extreme`` / ``self._Y_extreme``, with the time coordinate of
    the duplicates shifted by ``timedelta`` so indices stay unique.

    Fixes over the previous draft: ``dim``/``timedelta`` were undefined
    names (they belonged to ``multiply_extremes``) — the time dimension is
    now taken from ``self.time_dim`` and ``timedelta`` is a parameter with
    a default; the last-bin widening guard compared against an index the
    loop could never reach, so the global maximum was silently dropped;
    the extremes of every bin are now accumulated instead of each
    iteration overwriting the previous bin's result.

    :param bin_edges: monotonically increasing histogram edges
        (``len(bin_edges)`` == number of bins + 1).
    :param oversampling_rates: per-bin oversampling rate.
        TODO(review): the rate is read but not applied yet — each matching
        sample is duplicated exactly once regardless of its rate.
    :param timedelta: ``(amount, unit)`` forwarded to ``np.timedelta64``
        to shift the duplicated samples along the time dimension.
    """
    self._load()
    if (self._X is None) or (self._Y is None):
        logging.debug(f"{str(self.id_class)} has no data for X or Y, skip multiply extremes")
        return
    Y = self._Y
    X = self._X
    dim = self.time_dim
    n_bins = len(bin_edges) - 1
    extremes_X_all, extremes_Y_all = [], []
    for i_bin in range(n_bins):
        bin_start = bin_edges[i_bin]
        if i_bin == n_bins - 1:
            # widen the final bin so the half-open comparison below does not
            # exclude labels equal to the global maximum
            bin_end = bin_edges[i_bin + 1] + 1
        else:
            bin_end = bin_edges[i_bin + 1]
        rate = oversampling_rates[i_bin]  # TODO(review): apply the rate (currently unused)
        # extract extremes based on occurrence in labels: keep a time step if
        # any of its label values is >= bin_start AND any is < bin_end
        # (same semantics as the former concat/any/all construction)
        other_dims = remove_items(list(Y.dims), dim)
        extreme_idx = ((Y >= bin_start).any(dim=other_dims[0])
                       & (Y < bin_end).any(dim=other_dims[0]))
        extremes_X = list(map(lambda x: x.sel(**{dim: extreme_idx}), X))
        # shift duplicated inputs along the time dim so coords stay unique
        self._add_timedelta(extremes_X, dim, timedelta)
        extremes_Y = Y.sel(**{dim: extreme_idx})
        extremes_Y.coords[dim].values += np.timedelta64(*timedelta)
        extremes_X_all.append(extremes_X)
        extremes_Y_all.append(extremes_Y)
    # append the duplicates of all bins at once (previously each iteration
    # overwrote _X_extreme/_Y_extreme, keeping only the last bin)
    self._Y_extreme = xr.concat([Y] + extremes_Y_all, dim=dim)
    self._X_extreme = list(map(lambda *xs: xr.concat(xs, dim=dim), X, *extremes_X_all))
def multiply_extremes(self, extreme_values: num_or_list = 1., extremes_on_right_tail_only: bool = False,
timedelta: Tuple[int, str] = (1, 'm'), dim=DEFAULT_TIME_DIM):
"""
......
......@@ -19,7 +19,7 @@ from mlair.configuration.defaults import DEFAULT_STATIONS, DEFAULT_VAR_ALL_DICT,
DEFAULT_VAL_MIN_LENGTH, DEFAULT_TEST_START, DEFAULT_TEST_END, DEFAULT_TEST_MIN_LENGTH, DEFAULT_TRAIN_VAL_MIN_LENGTH, \
DEFAULT_USE_ALL_STATIONS_ON_ALL_DATA_SETS, DEFAULT_EVALUATE_BOOTSTRAPS, DEFAULT_CREATE_NEW_BOOTSTRAPS, \
DEFAULT_NUMBER_OF_BOOTSTRAPS, DEFAULT_PLOT_LIST, DEFAULT_SAMPLING, DEFAULT_DATA_ORIGIN, DEFAULT_ITER_DIM, \
DEFAULT_USE_MULTIPROCESSING, DEFAULT_USE_MULTIPROCESSING_ON_DEBUG
DEFAULT_USE_MULTIPROCESSING, DEFAULT_USE_MULTIPROCESSING_ON_DEBUG, DEFAULT_OVERSAMPLING_BINS, DEFAULT_OVERSAMPLING_RATES_CAP
from mlair.data_handler import DefaultDataHandler
from mlair.run_modules.run_environment import RunEnvironment
from mlair.model_modules.fully_connected_networks import FCN_64_32_16 as VanillaModel
......@@ -183,6 +183,9 @@ class ExperimentSetup(RunEnvironment):
:param use_multiprocessing: Enable parallel preprocessing (postprocessing not implemented yet) by setting this
        parameter to `True` (default). If set to `False` the computation is performed in a serial approach.
Multiprocessing is disabled when running in debug mode and cannot be switched on.
:param oversampling_bins: Sets the number of classes in which the training data is split. The training samples are then
oversampled according to the frequency of the different classes.
:param oversampling_rates_cap: Sets the maximum oversampling rate that is applied to a class
"""
......@@ -216,7 +219,7 @@ class ExperimentSetup(RunEnvironment):
hpc_hosts=None, model=None, batch_size=None, epochs=None, data_handler=None,
data_origin: Dict = None, competitors: list = None, competitor_path: str = None,
use_multiprocessing: bool = None, use_multiprocessing_on_debug: bool = None,
bins=None, rates_cap=None, **kwargs):
oversampling_bins=None, oversampling_rates_cap=None, **kwargs):
# create run framework
super().__init__()
......@@ -362,8 +365,8 @@ class ExperimentSetup(RunEnvironment):
self._set_param("model_class", model, VanillaModel)
# set params for oversampling
self._set_param("bins", bins, default=DEFAULT_BINS)
self._set_param("rates_cap", rates_cap, default=DEFAULT_RATES_CAP)
self._set_param("oversampling_bins", oversampling_bins, default=DEFAULT_OVERSAMPLING_BINS)
self._set_param("oversampling_rates_cap", oversampling_rates_cap, default=DEFAULT_OVERSAMPLING_RATES_CAP)
# set remaining kwargs
if len(kwargs) > 0:
......
......@@ -74,8 +74,8 @@ class PreProcessing(RunEnvironment):
def apply_oversampling(self):
    """Derive per-class oversampling rates from the 'train' data collection.

    Builds one label histogram over all training stations (shared min/max
    range), turns the normalized bin counts into oversampling rates capped at
    `oversampling_rates_cap`, stores rates and bin edges in the data store and
    asks each station handler to duplicate its rare samples accordingly.

    NOTE(review): this span is a diff excerpt — hunk separators and both the
    old and new versions of some lines are still present below; verify against
    the full file before relying on it.
    """
    #if request for oversampling=True/False
    data = self.data_store.get('data_collection', 'train')
    bins = self.data_store.get_default('bins')
    rates_cap = self.data_store.get_default('rates_cap')
    bins = self.data_store.get('oversampling_bins')
    rates_cap = self.data_store.get('oversampling_rates_cap')
    # NOTE(review): with an integer `bins` (default 10), np.array(bins) is a
    # 0-d array holding the value 10, so `histogram + hist` below offsets every
    # bin count by 10 — presumably this should be np.zeros(bins); confirm.
    histogram = np.array(bins)
    #get min and max of the whole data
    total_min = 0
    ......@@ -83,9 +83,10 @@
    for station in data:
        total_min = np.minimum(np.amin(station.get_Y(as_numpy=True)), total_min)
        total_max = np.maximum(np.amax(station.get_Y(as_numpy=True)), total_max)
    bin_edges = []
    for station in data:
        # Create histogram for each station
        hist, _ = np.histogram(station.get_Y(as_numpy=True), bins=bins, range=(total_min,total_max))
        # bin_edges is overwritten each iteration; identical for every station
        # because `bins` and `range` are fixed, so keeping the last one is safe
        hist, bin_edges = np.histogram(station.get_Y(as_numpy=True), bins=bins, range=(total_min,total_max))
        # Add up histograms
        histogram = histogram + hist
    # Scale down to most frequent class=1
    ......@@ -94,8 +95,11 @@
    # rate per bin = 1 / relative frequency, capped so rare bins are not
    # duplicated without bound
    oversampling_rates = 1 / histogram
    oversampling_rates_capped = np.minimum(oversampling_rates, rates_cap)
    # Add to datastore
    self.data_store.set('oversampling_rates', oversampling_rates, 'training')
    self.data_store.set('oversampling_rates_capped', oversampling_rates_capped, 'training')
    self.data_store.set('oversampling_rates', oversampling_rates, 'train')
    self.data_store.set('oversampling_rates_capped', oversampling_rates_capped, 'train')
    self.data_store.set('oversampling_bin_edges', bin_edges)
    for station in data:
        # each station duplicates its own rare samples with the shared rates
        station.apply_oversampling(bin_edges, oversampling_rates_capped)
def report_pre_processing(self):
"""Log some metrics on data and create latex report."""
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment