diff --git a/mlair/configuration/defaults.py b/mlair/configuration/defaults.py index 785aab88992e84a84ab4144040597922a48e5134..2b817f531a39ffdf0fb41d57e604a10734dba2fd 100644 --- a/mlair/configuration/defaults.py +++ b/mlair/configuration/defaults.py @@ -56,6 +56,9 @@ DEFAULT_DATA_ORIGIN = {"cloudcover": "REA", "humidity": "REA", "pblheight": "REA DEFAULT_USE_MULTIPROCESSING = True DEFAULT_USE_MULTIPROCESSING_ON_DEBUG = False +DEFAULT_BINS = 10 +DEFAULT_RATES_CAP = 20 + def get_defaults(): """Return all default parameters set in defaults.py""" diff --git a/mlair/run_modules/experiment_setup.py b/mlair/run_modules/experiment_setup.py index 24fedaa82615f93941ee737f13981e0c334259a9..e28eb76db0232f72fe4d6548e29a966bd4915951 100644 --- a/mlair/run_modules/experiment_setup.py +++ b/mlair/run_modules/experiment_setup.py @@ -215,7 +215,8 @@ class ExperimentSetup(RunEnvironment): create_new_bootstraps=None, data_path: str = None, batch_path: str = None, login_nodes=None, hpc_hosts=None, model=None, batch_size=None, epochs=None, data_handler=None, data_origin: Dict = None, competitors: list = None, competitor_path: str = None, - use_multiprocessing: bool = None, use_multiprocessing_on_debug: bool = None, **kwargs): + use_multiprocessing: bool = None, use_multiprocessing_on_debug: bool = None, + bins=None, rates_cap=None, **kwargs): # create run framework super().__init__() @@ -360,6 +361,10 @@ class ExperimentSetup(RunEnvironment): # set model architecture class self._set_param("model_class", model, VanillaModel) + # set params for oversampling + self._set_param("bins", bins, default=DEFAULT_BINS) + self._set_param("rates_cap", rates_cap, default=DEFAULT_RATES_CAP) + # set remaining kwargs if len(kwargs) > 0: for k, v in kwargs.items(): diff --git a/mlair/run_modules/pre_processing.py b/mlair/run_modules/pre_processing.py index 05bd61d941914c0e6bd332ce886ce2f62e293d8e..92cdac47939cee7b4f66b2c302f6c0adcb205422 100644 --- a/mlair/run_modules/pre_processing.py +++ b/mlair/run_modules/pre_processing.py @@ -71,9 +71,11 @@ class PreProcessing(RunEnvironment): self.report_pre_processing() self.prepare_competitors() - def apply_oversampling(self, bins=10, rates_cap=20): + def apply_oversampling(self): #if request for oversampling=True/False data = self.data_store.get('data_collection', 'train') + bins = self.data_store.get_default('bins') + rates_cap = self.data_store.get_default('rates_cap') histogram = np.array(bins) #get min and max of the whole data total_min = 0