diff --git a/src/data_handling/__init__.py b/src/data_handling/__init__.py index cb5aa5db0f29cf51d32ed54e810fa9b363d80cc6..dc14146417bd4a6bf30ccbb79537e7920313b077 100644 --- a/src/data_handling/__init__.py +++ b/src/data_handling/__init__.py @@ -13,3 +13,6 @@ from .bootstraps import BootStraps from .data_preparation_join import DataPrepJoin from .data_generator import DataGenerator from .data_distributor import Distributor +from .iterator import KerasIterator, DataCollection +from .advanced_data_handling import DataPreparation +from .data_preparation import StationPrep \ No newline at end of file diff --git a/src/data_handling/advanced_data_handling.py b/src/data_handling/advanced_data_handling.py index e36e0c75fc9107431a69482d46755acbdf5334bd..d4c4e363dfd1d27ad7ecb2ef34a619b61c20e9fb 100644 --- a/src/data_handling/advanced_data_handling.py +++ b/src/data_handling/advanced_data_handling.py @@ -4,6 +4,7 @@ __date__ = '2020-07-08' from src.helpers import to_list, remove_items +from src.data_handling.data_preparation import StationPrep import numpy as np import xarray as xr import pickle @@ -46,8 +47,8 @@ class DummyDataSingleStation: # pragma: no cover class DataPreparation: - def __init__(self, id_class, interpolate_dim: str, store_path, neighbors=None, min_length=0, - extreme_values: num_or_list = None,extremes_on_right_tail_only: bool = False,): + def __init__(self, id_class, interpolate_dim: str, data_path, neighbors=None, min_length=0, + extreme_values: num_or_list = None,extremes_on_right_tail_only: bool = False): self.id_class = id_class self.neighbors = to_list(neighbors) if neighbors is not None else [] self.interpolate_dim = interpolate_dim @@ -56,7 +57,7 @@ class DataPreparation: self._Y = None self._X_extreme = None self._Y_extreme = None - self._save_file = os.path.join(store_path, f"data_preparation_{str(self.id_class)}.pickle") + self._save_file = os.path.join(data_path, f"data_preparation_{str(self.id_class)}.pickle") self._collection = [] self._create_collection() self.harmonise_X() @@ -119,17 +120,17 @@ class DataPreparation: def _to_numpy(d): return list(map(lambda x: np.copy(x), d)) - def get_X(self, upsamling=False, as_numpy=True): + def get_X(self, upsampling=False, as_numpy=True): no_data = (self._X is None) self._load() if no_data is True else None - X = self._X if upsamling is False else self._X_extreme + X = self._X if upsampling is False else self._X_extreme self._reset_data() if no_data is True else None return self._to_numpy(X) if as_numpy is True else X - def get_Y(self, upsamling=False, as_numpy=True): + def get_Y(self, upsampling=False, as_numpy=True): no_data = (self._Y is None) self._load() if no_data is True else None - Y = self._Y if upsamling is False else self._Y_extreme + Y = self._Y if upsampling is False else self._Y_extreme self._reset_data() if no_data is True else None return self._to_numpy([Y]) if as_numpy is True else Y @@ -250,6 +251,41 @@ def create_data_prep(): data_prep.append(DataPreparation(neighbor2, interpolate_dim, path, neighbors=[neighbor1, central_station])) return data_prep + +class AbstractDataClass: + + def __init__(self): + self._requires = [] + + def __call__(self, *args, **kwargs): + raise NotImplementedError + + @property + def requirements(self): + return self._requires + + @requirements.setter + def requirements(self, value): + self._requires = value + + +class CustomDataClass(AbstractDataClass): + + def __init__(self): + import inspect + super().__init__() + self.sp_keys = 
remove_items(inspect.getfullargspec(StationPrep).args, ["self", "station"]) + self.dp_keys = remove_items(inspect.getfullargspec(DataPreparation).args, ["self", "id_class"]) + self.requirements = self.sp_keys + self.dp_keys + + def __call__(self, station, **kwargs): + sp_keys = {k: kwargs[k] for k in self.sp_keys if k in kwargs} + sp_keys["station"] = station + sp = StationPrep(**sp_keys) + dp_args = {k: kwargs[k] for k in self.dp_keys if k in kwargs} + return DataPreparation(sp, **dp_args) + + if __name__ == "__main__": from src.data_handling.data_preparation import StationPrep from src.data_handling.iterator import KerasIterator, DataCollection @@ -258,6 +294,6 @@ if __name__ == "__main__": for data in data_collection: print(data) path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "testdata", "keras") - keras_it = KerasIterator(data_collection, 100, path) + keras_it = KerasIterator(data_collection, 100, path, upsampling=True) keras_it[2] diff --git a/src/data_handling/data_preparation.py b/src/data_handling/data_preparation.py index d5933f193018efb1529db2c026981e8c4d7936d2..dadda2c58979ddb2678d366470c3b6d3f0584ee4 100644 --- a/src/data_handling/data_preparation.py +++ b/src/data_handling/data_preparation.py @@ -68,8 +68,8 @@ class AbstractStationPrep(): class StationPrep(AbstractStationPrep): - def __init__(self, path, station, statistics_per_var, transformation, station_type, network, sampling, target_dim, target_var, - interpolate_dim, window_history_size, window_lead_time, **kwargs): + def __init__(self, data_path, station, statistics_per_var, transformation, station_type, network, sampling, target_dim, target_var, + interpolate_dim, window_history_size, window_lead_time, overwrite_local_data: bool = False, **kwargs): super().__init__() # path, station, statistics_per_var, transformation, **kwargs) self.station_type = station_type self.network = network @@ -80,12 +80,10 @@ class StationPrep(AbstractStationPrep): self.window_history_size = window_history_size self.window_lead_time = window_lead_time - self.path = os.path.abspath(path) + self.path = os.path.abspath(data_path) self.station = helpers.to_list(station) self.statistics_per_var = statistics_per_var # self.target_dim = 'variable' - self.transformation = self.setup_transformation(transformation) - self.kwargs = kwargs # internal self.data = None @@ -95,17 +93,15 @@ class StationPrep(AbstractStationPrep): self.label = None self.observation = None - def __str__(self): - return self.station[0] + self.transformation = self.setup_transformation(transformation) + self.kwargs = kwargs + self.kwargs["overwrite_local_data"] = overwrite_local_data - def load_data(self): - try: - self.read_data_from_disk() - except FileNotFoundError: - self.download_data() - self.load_data() self.make_samples() + def __str__(self): + return self.station[0] + def __repr__(self): return f"StationPrep(path='{self.path}', station={self.station}, statistics_per_var={self.statistics_per_var}, " \ f"transformation={self.transformation}, station_type='{self.station_type}', network='{self.network}', " \ diff --git a/src/data_handling/iterator.py b/src/data_handling/iterator.py index 14d71a9afc23d3a0d80bacf60bbaa928fb34407a..d2ef9eb8df6373934e30ef9ca98c5de3fefed6c9 100644 --- a/src/data_handling/iterator.py +++ b/src/data_handling/iterator.py @@ -33,23 +33,37 @@ class StandardIterator(Iterator): class DataCollection(Iterable): - def __init__(self, collection: list): + def __init__(self, collection: list = None): + if collection is None: + collection = 
[] assert isinstance(collection, list) self._collection = collection + def __len__(self): + return len(self._collection) + def __iter__(self) -> Iterator: return StandardIterator(self._collection) + def __getitem__(self, index): + return self._collection[index] + + def add(self, element): + self._collection.append(element) + class KerasIterator(keras.utils.Sequence): - def __init__(self, collection: DataCollection, batch_size: int, path: str, shuffle: bool = False): + def __init__(self, collection: DataCollection, batch_size: int, batch_path: str, shuffle_batches: bool = False, + model=None, upsampling=False): self._collection = collection - self._path = os.path.join(path, "%i.pickle") + self._path = os.path.join(batch_path, "%i.pickle") self.batch_size = batch_size - self.shuffle = shuffle + self.model = model + self.shuffle = shuffle_batches + self.upsampling = upsampling self.indexes: list = [] - self._cleanup_path(path) + self._cleanup_path(batch_path) self._prepare_batches() def __len__(self) -> int: @@ -59,6 +73,19 @@ class KerasIterator(keras.utils.Sequence): """Get batch for given index.""" return self.__data_generation(self.indexes[index]) + def _get_model_rank(self): + if self.model is not None: + mod_out = self.model.output_shape + if isinstance(mod_out, tuple): # only one output branch: (None, ahead) + mod_rank = 1 + elif isinstance(mod_out, list): # multiple output branches, e.g.: [(None, ahead), (None, ahead)] + mod_rank = len(mod_out) + else: # pragma: no cover + raise TypeError("model output shape must either be tuple or list.") + return mod_rank + else: # no model provided, assume to use single output + return 1 + def __data_generation(self, index: int) -> Tuple[np.ndarray, np.ndarray]: """Load pickle data from disk.""" file = self._path % index @@ -75,6 +102,12 @@ class KerasIterator(keras.utils.Sequence): """Get batch according to batch size from data list.""" return list(map(lambda data: data[b * self.batch_size:(b+1) * self.batch_size, ...], data_list)) + def _permute_data(self, X, Y): + p = np.random.permutation(len(X[0])) # equiv to .shape[0] + X = list(map(lambda x: x[p], X)) + Y = list(map(lambda x: x[p], Y)) + return X, Y + def _prepare_batches(self) -> None: """ Prepare all batches as locally stored files. 
@@ -86,8 +119,12 @@ class KerasIterator(keras.utils.Sequence): """ index = 0 remaining = None + mod_rank = self._get_model_rank() for data in self._collection: - X, Y = data.get_X(), data.get_Y() + X = data.get_X(upsampling=self.upsampling) + Y = [data.get_Y(upsampling=self.upsampling)[0] for _ in range(mod_rank)] + if self.upsampling: + X, Y = self._permute_data(X, Y) if remaining is not None: X, Y = self._concatenate(X, remaining[0]), self._concatenate(Y, remaining[1]) length = X[0].shape[0] diff --git a/src/model_modules/model_class.py b/src/model_modules/model_class.py index dab2e168c5a9f87d4aee42fc94489fd0fa67772a..6b3b9972bc0af4c968f2831963cc18446ff09162 100644 --- a/src/model_modules/model_class.py +++ b/src/model_modules/model_class.py @@ -351,7 +351,7 @@ class MyLittleModel(AbstractModelClass): # settings self.window_history_size = window_history_size self.window_lead_time = window_lead_time - self.channels = channels + self.channels = channels[0] self.dropout_rate = 0.1 self.regularizer = keras.regularizers.l2(0.1) self.activation = keras.layers.PReLU @@ -387,7 +387,7 @@ class MyLittleModel(AbstractModelClass): x_in = self.activation()(x_in) x_in = keras.layers.Dense(self.window_lead_time, name='{}_Dense'.format("major"))(x_in) out_main = self.activation()(x_in) - self.model = keras.Model(inputs=x_input, outputs=[out_main]) + self.model = keras.Model(inputs=[x_input], outputs=[out_main]) def set_compile_options(self): self.initial_lr = 1e-2 @@ -423,7 +423,7 @@ class MyBranchedModel(AbstractModelClass): # settings self.window_history_size = window_history_size self.window_lead_time = window_lead_time - self.channels = channels + self.channels = channels[0] self.dropout_rate = 0.1 self.regularizer = keras.regularizers.l2(0.1) self.activation = keras.layers.PReLU @@ -493,7 +493,7 @@ class MyTowerModel(AbstractModelClass): # settings self.window_history_size = window_history_size self.window_lead_time = window_lead_time - self.channels = channels + self.channels = channels[0] self.dropout_rate = 1e-2 self.regularizer = keras.regularizers.l2(0.1) self.initial_lr = 1e-2 @@ -605,7 +605,7 @@ class MyPaperModel(AbstractModelClass): # settings self.window_history_size = window_history_size self.window_lead_time = window_lead_time - self.channels = channels + self.channels = channels[0] self.dropout_rate = .3 self.regularizer = keras.regularizers.l2(0.001) self.initial_lr = 1e-3 diff --git a/src/run.py b/src/run.py index 7e262dd769204077697b7df3f3fbaedb4c012257..4033d52303035ede583529169e93548ab7a205e1 100644 --- a/src/run.py +++ b/src/run.py @@ -39,5 +39,5 @@ def run(stations=None, if __name__ == "__main__": - - run() + from src.data_handling.advanced_data_handling import CustomDataClass + run(data_preparation=CustomDataClass, statistics_per_var={'o3': 'dma8eu'}, transformation={}) diff --git a/src/run_modules/experiment_setup.py b/src/run_modules/experiment_setup.py index 1d375c32be06b583abbfb06a20ea482e6775b232..3e471dda7934fc53c990ede6e459c41f3ef6229b 100644 --- a/src/run_modules/experiment_setup.py +++ b/src/run_modules/experiment_setup.py @@ -19,6 +19,7 @@ from src.configuration.defaults import DEFAULT_STATIONS, DEFAULT_VAR_ALL_DICT, D DEFAULT_USE_ALL_STATIONS_ON_ALL_DATA_SETS, DEFAULT_EVALUATE_BOOTSTRAPS, DEFAULT_CREATE_NEW_BOOTSTRAPS, \ DEFAULT_NUMBER_OF_BOOTSTRAPS, DEFAULT_PLOT_LIST from src.data_handling import DataPrepJoin +from src.data_handling.advanced_data_handling import CustomDataClass from src.run_modules.run_environment import RunEnvironment from 
src.model_modules.model_class import MyLittleModel as VanillaModel
@@ -228,8 +229,8 @@ class ExperimentSetup(RunEnvironment):
                  create_new_model = None, bootstrap_path=None, permute_data_on_training = None, transformation=None,
                  train_min_length=None, val_min_length=None, test_min_length=None, extreme_values: list = None,
                  extremes_on_right_tail_only: bool = None, evaluate_bootstraps=None, plot_list=None, number_of_bootstraps=None,
-                 create_new_bootstraps=None, data_path: str = None, login_nodes=None, hpc_hosts=None, model=None,
-                 batch_size=None, epochs=None, data_preparation=None):
+                 create_new_bootstraps=None, data_path: str = None, batch_path: str = None, login_nodes=None,
+                 hpc_hosts=None, model=None, batch_size=None, epochs=None, data_preparation=None):

         # create run framework
         super().__init__()
@@ -265,6 +266,9 @@ class ExperimentSetup(RunEnvironment):
         logging.info(f"Experiment path is: {experiment_path}")
         path_config.check_path_and_create(self.data_store.get("experiment_path"))

+        # batch path (temporary)
+        self._set_param("batch_path", batch_path, default=os.path.join(experiment_path, "batch_data"))
+
         # set model path
         self._set_param("model_path", None, os.path.join(experiment_path, "model"))
         path_config.check_path_and_create(self.data_store.get("model_path"))
@@ -297,7 +301,8 @@ class ExperimentSetup(RunEnvironment):
         self._set_param("sampling", sampling)
         self._set_param("transformation", transformation, default=DEFAULT_TRANSFORMATION)
         self._set_param("transformation", None, scope="preprocessing")
-        self._set_param("data_preparation", data_preparation, default=DataPrepJoin)
+        self._set_param("data_preparation", data_preparation() if data_preparation is not None else None, default=CustomDataClass())
+        assert not isinstance(getattr(self.data_store.get("data_preparation"), "requirements"), property)  # an instance, not the class, must be stored

         # target
         self._set_param("target_var", target_var, default=DEFAULT_TARGET_VAR)
diff --git a/src/run_modules/model_setup.py b/src/run_modules/model_setup.py
index f9683b953d85bacf6e452e0a1922e85dfe946cd1..dc537eb1cd3e5cf04fbdddee3017e4ace7f7bfca 100644
--- a/src/run_modules/model_setup.py
+++ b/src/run_modules/model_setup.py
@@ -90,7 +90,7 @@ class ModelSetup(RunEnvironment):

     def _set_channels(self):
         """Set channels as number of variables of train generator."""
-        channels = self.data_store.get("generator", "train")[0][0].shape[-1]
+        channels = list(map(lambda x: x[0].shape[-1], self.data_store.get("data_collection", "train")[0].get_X()))
         self.data_store.set("channels", channels, self.scope)

     def compile_model(self):
diff --git a/src/run_modules/pre_processing.py b/src/run_modules/pre_processing.py
index db7fff2ab9e385ce769f86ef95d1565ea783cc95..72493c1fbad42a7aa9fec1e32292c0727a7dfb38 100644
--- a/src/run_modules/pre_processing.py
+++ b/src/run_modules/pre_processing.py
@@ -11,6 +11,8 @@ import numpy as np
 import pandas as pd

 from src.data_handling import DataGenerator
+from src.data_handling import DataCollection, DataPreparation, StationPrep
+from src.data_handling.advanced_data_handling import CustomDataClass
 from src.helpers import TimeTracking
 from src.configuration import path_config
 from src.helpers.join import EmptyQueryResult
@@ -59,10 +61,9 @@ class PreProcessing(RunEnvironment):
         self._run()

     def _run(self):
-        args = self.data_store.create_args_dict(DEFAULT_ARGS_LIST, scope="preprocessing")
-        kwargs = self.data_store.create_args_dict(DEFAULT_KWARGS_LIST, scope="preprocessing")
         stations = self.data_store.get("stations")
-        valid_stations = self.check_valid_stations(args, kwargs, stations, load_tmp=False, save_tmp=False, 
name="all") + data_preparation = self.data_store.get("data_preparation") + _, valid_stations = self.validate_station(data_preparation, stations, "preprocessing", overwrite_local_data=True) self.data_store.set("stations", valid_stations) self.split_train_val_test() self.report_pre_processing() @@ -70,16 +71,14 @@ class PreProcessing(RunEnvironment): def report_pre_processing(self): """Log some metrics on data and create latex report.""" logging.debug(20 * '##') - n_train = len(self.data_store.get('generator', 'train')) - n_val = len(self.data_store.get('generator', 'val')) - n_test = len(self.data_store.get('generator', 'test')) + n_train = len(self.data_store.get('data_collection', 'train')) + n_val = len(self.data_store.get('data_collection', 'val')) + n_test = len(self.data_store.get('data_collection', 'test')) n_total = n_train + n_val + n_test logging.debug(f"Number of all stations: {n_total}") logging.debug(f"Number of training stations: {n_train}") logging.debug(f"Number of val stations: {n_val}") logging.debug(f"Number of test stations: {n_test}") - logging.debug(f"TEST SHAPE OF GENERATOR CALL: {self.data_store.get('generator', 'test')[0][0].shape}" - f"{self.data_store.get('generator', 'test')[0][1].shape}") self.create_latex_report() def create_latex_report(self): @@ -121,11 +120,12 @@ class PreProcessing(RunEnvironment): set_names = ["train", "val", "test"] df = pd.DataFrame(columns=meta_data + set_names) for set_name in set_names: - data: DataGenerator = self.data_store.get("generator", set_name) - for station in data.stations: - df.loc[station, set_name] = data.get_data_generator(station).get_transposed_label().shape[0] - if df.loc[station, meta_data].isnull().any(): - df.loc[station, meta_data] = data.get_data_generator(station).meta.loc[meta_data].values.flatten() + data = self.data_store.get("data_collection", set_name) + for station in data: + station_name = str(station.id_class) + df.loc[station_name, set_name] = station.get_Y()[0].shape[0] + if df.loc[station_name, meta_data].isnull().any(): + df.loc[station_name, meta_data] = station.id_class.meta.loc[meta_data].values.flatten() df.loc["# Samples", set_name] = df.loc[:, set_name].sum() df.loc["# Stations", set_name] = df.loc[:, set_name].count() df[meta_round] = df[meta_round].astype(float).round(precision) @@ -147,7 +147,7 @@ class PreProcessing(RunEnvironment): Split data into subsets. Currently: train, val, test and train_val (actually this is only the merge of train and val, but as an separate - generator). IMPORTANT: Do not change to order of the execution of create_set_split. The train subset needs + data_collection). IMPORTANT: Do not change to order of the execution of create_set_split. The train subset needs always to be executed at first, to set a proper transformation. """ fraction_of_training = self.data_store.get("fraction_of_training") @@ -159,7 +159,7 @@ class PreProcessing(RunEnvironment): raise AssertionError(f"Make sure, that the train subset is always at first execution position! 
Given subset" f"order was: {subset_names}.") for (ind, scope) in zip([train_index, val_index, test_index, train_val_index], subset_names): - self.create_set_split(ind, scope) + self.create_set_split_new(ind, scope) @staticmethod def split_set_indices(total_length: int, fraction: float) -> Tuple[slice, slice, slice, slice]: @@ -183,13 +183,27 @@ class PreProcessing(RunEnvironment): train_val_index = slice(0, pos_test_split) return train_index, val_index, test_index, train_val_index + def create_set_split_new(self, index_list: slice, set_name: str) -> None: + # get set stations + stations = self.data_store.get("stations", scope=set_name) + if self.data_store.get("use_all_stations_on_all_data_sets"): + set_stations = stations + else: + set_stations = stations[index_list] + logging.debug(f"{set_name.capitalize()} stations (len={len(set_stations)}): {set_stations}") + # create set data_collection and store + data_preparation = self.data_store.get("data_preparation") + collection, valid_stations = self.validate_station(data_preparation, set_stations, set_name) + self.data_store.set("stations", valid_stations, scope=set_name) + self.data_store.set("data_collection", collection, scope=set_name) + def create_set_split(self, index_list: slice, set_name: str) -> None: """ Create subsets and store in data store. - Create the subset for given split index and stores the DataGenerator with given set name in data store as - `generator`. Check for all valid stations using the default (kw)args for given scope and create the - DataGenerator for all valid stations. Also set all transformation information, if subset is training set. Make + Create the subset for given split index and stores the data_collection with given set name in data store as + `data_collection`. Check for all valid stations using the default (kw)args for given scope and create the + data_collection for all valid stations. Also set all transformation information, if subset is training set. Make sure, that the train set is executed first, and all other subsets afterwards. :param index_list: list of all stations to use for the set. 
If attribute use_all_stations_on_all_data_sets=True,
@@ -207,13 +221,18 @@
         # validate set
         set_stations = self.check_valid_stations(args, kwargs, set_stations, load_tmp=False, name=set_name)
         self.data_store.set("stations", set_stations, scope=set_name)
-        # create set generator and store
+        # create set data_collection and store
         set_args = self.data_store.create_args_dict(DEFAULT_ARGS_LIST, scope=set_name)
-        data_set = DataGenerator(**set_args, **kwargs)
-        self.data_store.set("generator", data_set, scope=set_name)
-        # extract transformation from train set
-        if set_name == "train":
-            self.data_store.set("transformation", data_set.transformation)
+        data_prep_kwargs = self.data_store.create_args_dict(["interpolate_dim", "data_path", "min_length", "extreme_values", "extremes_on_right_tail_only"], scope=set_name)
+        collection = DataCollection()
+        for station in set_stations:
+            set_args["station"] = station
+
+            def f(sp_args, sp_kwargs, dp_kwargs):
+                return DataPreparation(StationPrep(**sp_args, **sp_kwargs), **dp_kwargs)
+
+            collection.add(f(set_args, kwargs, data_prep_kwargs))
+        self.data_store.set("data_collection", collection, scope=set_name)

     @staticmethod
     def check_valid_stations(args: Dict, kwargs: Dict, all_stations: List[str], load_tmp=True, save_tmp=True,
@@ -257,3 +276,37 @@
         logging.info(f"run for {t_outer} to check {len(all_stations)} station(s). Found {len(valid_stations)}/"
                      f"{len(all_stations)} valid stations.")
         return valid_stations
+
+    def validate_station(self, data_preparation, set_stations, set_name=None, overwrite_local_data=False):
+        """
+        Check if all given stations in `set_stations` are valid.
+
+        Valid means that the data preparation class can load and prepare data for the given station with the
+        arguments collected from the data store (according to `data_preparation.requirements`).
+
+        :param data_preparation: data preparation factory (e.g. an instance of `CustomDataClass`) that is called once
+            per station; its `requirements` define which arguments are taken from the data store.
+        :param set_stations: all stations to check.
+        :param set_name: name of the subset (and data store scope) to display in the logging info message.
+        :param overwrite_local_data: if True, replace already downloaded local data by a fresh download.
+
+        :return: tuple of a DataCollection with all valid stations and the corrected list of valid station IDs.
+        """
+        t_outer = TimeTracking()
+        logging.info(f"check valid stations started{' (%s)' % set_name if set_name is not None else ' (all)'}")
+        collection = DataCollection()
+        valid_stations = []
+        kwargs = self.data_store.create_args_dict(data_preparation.requirements, scope=set_name)
+        kwargs.update({"overwrite_local_data": overwrite_local_data})
+        for station in set_stations:
+            try:
+                dp = data_preparation(station, **kwargs)
+                collection.add(dp)
+                valid_stations.append(station)
+            except (AttributeError, EmptyQueryResult):
+                continue
+        logging.info(f"run for {t_outer} to check {len(set_stations)} station(s). 
Found {len(collection)}/" + f"{len(set_stations)} valid stations.") + return collection, valid_stations + diff --git a/src/run_modules/training.py b/src/run_modules/training.py index 1a0d7beb1ec37bb5e59a4129da58572d79a73636..a92fd56fda5599489992b1bccaca3a715dd622d7 100644 --- a/src/run_modules/training.py +++ b/src/run_modules/training.py @@ -11,7 +11,7 @@ from typing import Union import keras from keras.callbacks import Callback, History -from src.data_handling import Distributor +from src.data_handling import KerasIterator from src.model_modules.keras_extensions import CallbackHandler from src.plotting.training_monitoring import PlotModelHistory, PlotModelLearningRate from src.run_modules.run_environment import RunEnvironment @@ -64,9 +64,9 @@ class Training(RunEnvironment): """Set up and run training.""" super().__init__() self.model: keras.Model = self.data_store.get("model", "model") - self.train_set: Union[Distributor, None] = None - self.val_set: Union[Distributor, None] = None - self.test_set: Union[Distributor, None] = None + self.train_set: Union[KerasIterator, None] = None + self.val_set: Union[KerasIterator, None] = None + self.test_set: Union[KerasIterator, None] = None self.batch_size = self.data_store.get("batch_size") self.epochs = self.data_store.get("epochs") self.callbacks: CallbackHandler = self.data_store.get("callbacks", "model") @@ -102,9 +102,9 @@ class Training(RunEnvironment): :param mode: name of set, should be from ["train", "val", "test"] """ - gen = self.data_store.get("generator", mode) - kwargs = self.data_store.create_args_dict(["permute_data", "upsampling"], scope=mode) - setattr(self, f"{mode}_set", Distributor(gen, self.model, self.batch_size, **kwargs)) + collection = self.data_store.get("data_collection", mode) + kwargs = self.data_store.create_args_dict(["upsampling", "shuffle_batches", "batch_path"], scope=mode) + setattr(self, f"{mode}_set", KerasIterator(collection, self.batch_size, model=self.model, **kwargs)) def set_generators(self) -> None: """ @@ -128,15 +128,15 @@ class Training(RunEnvironment): """ logging.info(f"Train with {len(self.train_set)} mini batches.") logging.info(f"Train with option upsampling={self.train_set.upsampling}.") - logging.info(f"Train with option data_permutation={self.train_set.do_data_permutation}.") + logging.info(f"Train with option shuffle={self.train_set.shuffle}.") checkpoint = self.callbacks.get_checkpoint() if not os.path.exists(checkpoint.filepath) or self._create_new_model: - history = self.model.fit_generator(generator=self.train_set.distribute_on_batches(), + history = self.model.fit_generator(generator=self.train_set, steps_per_epoch=len(self.train_set), epochs=self.epochs, verbose=2, - validation_data=self.val_set.distribute_on_batches(), + validation_data=self.val_set, validation_steps=len(self.val_set), callbacks=self.callbacks.get_callbacks(as_dict=False)) else: @@ -146,11 +146,11 @@ class Training(RunEnvironment): self.model = keras.models.load_model(checkpoint.filepath) hist: History = self.callbacks.get_callback_by_name("hist") initial_epoch = max(hist.epoch) + 1 - _ = self.model.fit_generator(generator=self.train_set.distribute_on_batches(), + _ = self.model.fit_generator(generator=self.train_set, steps_per_epoch=len(self.train_set), epochs=self.epochs, verbose=2, - validation_data=self.val_set.distribute_on_batches(), + validation_data=self.val_set, validation_steps=len(self.val_set), callbacks=self.callbacks.get_callbacks(as_dict=False), initial_epoch=initial_epoch)
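
Usage sketch (not part of the patch): a minimal, hedged example of how the pieces introduced above are meant to interact — CustomDataClass exposes its `requirements`, PreProcessing builds one DataPreparation per station into a DataCollection, and Training wraps that collection in a KerasIterator. The station IDs, paths and keyword values below are illustrative assumptions, not defaults taken from this change; only `statistics_per_var` and `transformation` mirror the values used in run.py. Running it requires that station data can actually be loaded under the assumed `data_path`, and `upsampling=True` would additionally require `extreme_values` to be set on DataPreparation.

    import os

    from src.data_handling import DataCollection, KerasIterator
    from src.data_handling.advanced_data_handling import CustomDataClass

    factory = CustomDataClass()
    # every StationPrep/DataPreparation argument except station/id_class;
    # ExperimentSetup/PreProcessing collect exactly these keys from the data store
    print(factory.requirements)

    # assumed, illustrative settings (normally resolved from the data store via `requirements`)
    settings = dict(data_path="testdata", statistics_per_var={"o3": "dma8eu"}, transformation={},
                    station_type="background", network="AIRBASE", sampling="daily",
                    target_dim="variables", target_var="o3", interpolate_dim="datetime",
                    window_history_size=7, window_lead_time=3)

    collection = DataCollection()
    for station in ["DEBW107", "DEBW013"]:  # assumed station IDs with data available locally
        collection.add(factory(station, **settings))

    batch_path = os.path.join("testdata", "batch_data")
    # model=None -> a single output branch (model rank 1); shuffle_batches only shuffles batch order,
    # while upsampling=True would also draw the extreme-value samples and permute within each station block
    keras_it = KerasIterator(collection, batch_size=64, batch_path=batch_path,
                             shuffle_batches=True, model=None, upsampling=False)
    x, y = keras_it[0]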