diff --git a/src/model_modules/model_class.py b/src/model_modules/model_class.py
index cfe058f2a5682a70475818916c6e160f51efc7b3..535941f8975a5c5c88c5c69278d787e6d89b9f81 100644
--- a/src/model_modules/model_class.py
+++ b/src/model_modules/model_class.py
@@ -133,7 +133,7 @@ class AbstractModelClass(ABC):
     the corresponding loss function.
     """
 
-    def __init__(self) -> None:
+    def __init__(self, shape_inputs, shape_outputs) -> None:
         """Predefine internal attributes for model and loss."""
         self.__model = None
         self.model_name = self.__class__.__name__
@@ -147,6 +147,8 @@ class AbstractModelClass(ABC):
             'target_tensors': None
         }
         self.__compile_options = self.__allowed_compile_options
+        self.shape_inputs = shape_inputs
+        self.shape_outputs = self.__extract_from_tuple(shape_outputs)
 
     def __getattr__(self, name: str) -> Any:
         """
@@ -267,6 +269,11 @@ class AbstractModelClass(ABC):
                     raise ValueError(
                         f"Got different values or arguments for same argument: self.{allow_k}={new_v_attr.__class__} and '{allow_k}': {new_v_dic.__class__}")
 
+    @staticmethod
+    def __extract_from_tuple(tup):
+        """Return element of tuple if it contains only a single element."""
+        return tup[0] if isinstance(tup, tuple) and len(tup) == 1 else tup
+
     @staticmethod
     def __compare_keras_optimizers(first, second):
         if first.__class__ == second.__class__ and first.__module__ == 'keras.optimizers':
@@ -334,24 +341,18 @@ class MyLittleModel(AbstractModelClass):
     Dense layer.
     """
 
-    def __init__(self, window_history_size, window_lead_time, channels):
+    def __init__(self, shape_inputs, shape_outputs):
         """
         Sets model and loss depending on the given arguments.
 
-        :param activation: activation function
-        :param window_history_size: number of historical time steps included in the input data
-        :param channels: number of variables used in input data
-        :param regularizer: <not used here>
-        :param dropout_rate: dropout rate used in the model [0, 1)
-        :param window_lead_time: number of time steps to forecast in the output layer
+        :param shape_inputs: list of input shapes (expect len=1 with shape=(window_hist, station, variables))
+        :param shape_outputs: list of output shapes (expect len=1 with shape=(window_forecast,))
         """
-        super().__init__()
+        assert len(shape_inputs) == 1
+        super().__init__(shape_inputs[0], shape_outputs[0])
 
         # settings
-        self.window_history_size = window_history_size
-        self.window_lead_time = window_lead_time
-        self.channels = channels[0]
        self.dropout_rate = 0.1
         self.regularizer = keras.regularizers.l2(0.1)
         self.activation = keras.layers.PReLU
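
A quick behaviour check of the new __extract_from_tuple helper (a standalone copy for illustration; the shape values below are assumed, not taken from this patch):

    def extract_from_tuple(tup):
        # mirrors AbstractModelClass.__extract_from_tuple above
        return tup[0] if isinstance(tup, tuple) and len(tup) == 1 else tup

    assert extract_from_tuple((4,)) == 4                # a 1-tuple collapses to a scalar
    assert extract_from_tuple((7, 1, 9)) == (7, 1, 9)   # longer tuples pass through
    assert extract_from_tuple(4) == 4                   # non-tuples pass through unchanged
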
@@ -364,17 +365,10 @@ class MyLittleModel(AbstractModelClass):
 
     def set_model(self):
         """
         Build the model.
-
-        :param activation: activation function
-        :param window_history_size: number of historical time steps included in the input data
-        :param channels: number of variables used in input data
-        :param dropout_rate: dropout rate used in the model [0, 1)
-        :param window_lead_time: number of time steps to forecast in the output layer
-        :return: built keras model
         """
         # add 1 to window_size to include current time step t0
-        x_input = keras.layers.Input(shape=(self.window_history_size + 1, 1, self.channels))
+        x_input = keras.layers.Input(shape=self.shape_inputs)
         x_in = keras.layers.Conv2D(32, (1, 1), padding='same', name='{}_Conv_1x1'.format("major"))(x_input)
         x_in = self.activation(name='{}_conv_act'.format("major"))(x_in)
         x_in = keras.layers.Flatten(name='{}'.format("major"))(x_in)
@@ -385,13 +379,12 @@ class MyLittleModel(AbstractModelClass):
         x_in = self.activation()(x_in)
         x_in = keras.layers.Dense(16, name='{}_Dense_16'.format("major"))(x_in)
         x_in = self.activation()(x_in)
-        x_in = keras.layers.Dense(self.window_lead_time, name='{}_Dense'.format("major"))(x_in)
+        x_in = keras.layers.Dense(self.shape_outputs, name='{}_Dense'.format("major"))(x_in)
         out_main = self.activation()(x_in)
         self.model = keras.Model(inputs=x_input, outputs=[out_main])
 
     def set_compile_options(self):
         self.initial_lr = 1e-2
-        # self.optimizer = keras.optimizers.SGD(lr=self.initial_lr, momentum=0.9)
         self.optimizer = keras.optimizers.adam(lr=self.initial_lr)
         self.lr_decay = src.model_modules.keras_extensions.LearningRateDecay(base_lr=self.initial_lr, drop=.94,
                                                                              epochs_drop=10)
@@ -407,24 +400,18 @@ class MyBranchedModel(AbstractModelClass):
     Dense layer.
     """
 
-    def __init__(self, window_history_size, window_lead_time, channels):
+    def __init__(self, shape_inputs, shape_outputs):
         """
         Sets model and loss depending on the given arguments.
 
-        :param activation: activation function
-        :param window_history_size: number of historical time steps included in the input data
-        :param channels: number of variables used in input data
-        :param regularizer: <not used here>
-        :param dropout_rate: dropout rate used in the model [0, 1)
-        :param window_lead_time: number of time steps to forecast in the output layer
+        :param shape_inputs: list of input shapes (expect len=1 with shape=(window_hist, station, variables))
+        :param shape_outputs: list of output shapes (expect len=1 with shape=(window_forecast,))
         """
-        super().__init__()
+        assert len(shape_inputs) == 1
+        super().__init__(shape_inputs[0], shape_outputs[0])
 
         # settings
-        self.window_history_size = window_history_size
-        self.window_lead_time = window_lead_time
-        self.channels = channels[0]
         self.dropout_rate = 0.1
         self.regularizer = keras.regularizers.l2(0.1)
         self.activation = keras.layers.PReLU
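
Only shape_outputs is collapsed because keras.layers.Input expects a full shape tuple while Dense expects a plain integer unit count. A minimal, self-contained sketch of that asymmetry (shape values are assumed for illustration):

    import keras

    shape_inputs = (7, 1, 9)   # hypothetical (window_hist + 1, station, variables)
    shape_outputs = 4          # hypothetical; (4,) collapsed by the base class

    x_input = keras.layers.Input(shape=shape_inputs)
    x = keras.layers.Flatten()(x_input)
    out = keras.layers.Dense(shape_outputs)(x)
    model = keras.Model(inputs=x_input, outputs=[out])  # output shape: (None, 4)
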
@@ -437,32 +424,25 @@ class MyBranchedModel(AbstractModelClass):
 
     def set_model(self):
         """
         Build the model.
-
-        :param activation: activation function
-        :param window_history_size: number of historical time steps included in the input data
-        :param channels: number of variables used in input data
-        :param dropout_rate: dropout rate used in the model [0, 1)
-        :param window_lead_time: number of time steps to forecast in the output layer
-        :return: built keras model
         """
         # add 1 to window_size to include current time step t0
-        x_input = keras.layers.Input(shape=(self.window_history_size + 1, 1, self.channels))
+        x_input = keras.layers.Input(shape=self.shape_inputs)
         x_in = keras.layers.Conv2D(32, (1, 1), padding='same', name='{}_Conv_1x1'.format("major"))(x_input)
         x_in = self.activation(name='{}_conv_act'.format("major"))(x_in)
         x_in = keras.layers.Flatten(name='{}'.format("major"))(x_in)
         x_in = keras.layers.Dropout(self.dropout_rate, name='{}_Dropout_1'.format("major"))(x_in)
         x_in = keras.layers.Dense(64, name='{}_Dense_64'.format("major"))(x_in)
         x_in = self.activation()(x_in)
-        out_minor_1 = keras.layers.Dense(self.window_lead_time, name='{}_Dense'.format("minor_1"))(x_in)
+        out_minor_1 = keras.layers.Dense(self.shape_outputs, name='{}_Dense'.format("minor_1"))(x_in)
         out_minor_1 = self.activation(name="minor_1")(out_minor_1)
         x_in = keras.layers.Dense(32, name='{}_Dense_32'.format("major"))(x_in)
         x_in = self.activation()(x_in)
-        out_minor_2 = keras.layers.Dense(self.window_lead_time, name='{}_Dense'.format("minor_2"))(x_in)
+        out_minor_2 = keras.layers.Dense(self.shape_outputs, name='{}_Dense'.format("minor_2"))(x_in)
         out_minor_2 = self.activation(name="minor_2")(out_minor_2)
         x_in = keras.layers.Dense(16, name='{}_Dense_16'.format("major"))(x_in)
         x_in = self.activation()(x_in)
-        x_in = keras.layers.Dense(self.window_lead_time, name='{}_Dense'.format("major"))(x_in)
+        x_in = keras.layers.Dense(self.shape_outputs, name='{}_Dense'.format("major"))(x_in)
         out_main = self.activation(name="main")(x_in)
         self.model = keras.Model(inputs=x_input, outputs=[out_minor_1, out_minor_2, out_main])
@@ -477,24 +457,18 @@ class MyTowerModel(AbstractModelClass):
 
-    def __init__(self, window_history_size, window_lead_time, channels):
+    def __init__(self, shape_inputs, shape_outputs):
         """
         Sets model and loss depending on the given arguments.
 
-        :param activation: activation function
-        :param window_history_size: number of historical time steps included in the input data
-        :param channels: number of variables used in input data
-        :param regularizer: <not used here>
-        :param dropout_rate: dropout rate used in the model [0, 1)
-        :param window_lead_time: number of time steps to forecast in the output layer
+        :param shape_inputs: list of input shapes (expect len=1 with shape=(window_hist, station, variables))
+        :param shape_outputs: list of output shapes (expect len=1 with shape=(window_forecast,))
         """
-        super().__init__()
+        assert len(shape_inputs) == 1
+        super().__init__(shape_inputs[0], shape_outputs[0])
 
         # settings
-        self.window_history_size = window_history_size
-        self.window_lead_time = window_lead_time
-        self.channels = channels[0]
         self.dropout_rate = 1e-2
         self.regularizer = keras.regularizers.l2(0.1)
         self.initial_lr = 1e-2
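
All concrete models now follow the same constructor contract. A hypothetical minimal subclass would look roughly like this (class name and layers are illustrative only; it assumes the compile_options setter accepts a plain dict, as the base-class bookkeeping above suggests):

    class MyMinimalModel(AbstractModelClass):

        def __init__(self, shape_inputs: list, shape_outputs: list):
            assert len(shape_inputs) == 1
            super().__init__(shape_inputs[0], shape_outputs[0])
            self.set_model()
            self.set_compile_options()

        def set_model(self):
            # shape_inputs is still a tuple, shape_outputs is already an int
            x_input = keras.layers.Input(shape=self.shape_inputs)
            x_in = keras.layers.Flatten()(x_input)
            x_in = keras.layers.Dense(self.shape_outputs)(x_in)
            self.model = keras.Model(inputs=x_input, outputs=[x_in])

        def set_compile_options(self):
            self.compile_options = {"optimizer": keras.optimizers.SGD(), "loss": "mse"}
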
@@ -510,13 +484,6 @@ class MyTowerModel(AbstractModelClass):
     def set_model(self):
         """
         Build the model.
-
-        :param activation: activation function
-        :param window_history_size: number of historical time steps included in the input data
-        :param channels: number of variables used in input data
-        :param dropout_rate: dropout rate used in the model [0, 1)
-        :param window_lead_time: number of time steps to forecast in the output layer
-        :return: built keras model
         """
         activation = self.activation
         conv_settings_dict1 = {
@@ -550,9 +517,7 @@ class MyTowerModel(AbstractModelClass):
 
         ##########################################
         inception_model = InceptionModelBase()
 
-        X_input = keras.layers.Input(
-            shape=(
-                self.window_history_size + 1, 1, self.channels))  # add 1 to window_size to include current time step t0
+        X_input = keras.layers.Input(shape=self.shape_inputs)
 
         X_in = inception_model.inception_block(X_input, conv_settings_dict1, pool_settings_dict1,
                                                regularizer=self.regularizer,
@@ -574,7 +539,7 @@ class MyTowerModel(AbstractModelClass):
 
         # out_main = flatten_tail(X_in, 'Main', activation=activation, bound_weight=True, dropout_rate=self.dropout_rate,
         #                         reduction_filter=64, inner_neurons=64, output_neurons=self.window_lead_time)
-        out_main = flatten_tail(X_in, inner_neurons=64, activation=activation, output_neurons=self.window_lead_time,
+        out_main = flatten_tail(X_in, inner_neurons=64, activation=activation, output_neurons=self.shape_outputs,
                                 output_activation='linear', reduction_filter=64,
                                 name='Main', bound_weight=True, dropout_rate=self.dropout_rate,
                                 kernel_regularizer=self.regularizer
@@ -589,24 +554,18 @@ class MyPaperModel(AbstractModelClass):
 
-    def __init__(self, window_history_size, window_lead_time, channels):
+    def __init__(self, shape_inputs, shape_outputs):
         """
         Sets model and loss depending on the given arguments.
 
-        :param activation: activation function
-        :param window_history_size: number of historical time steps included in the input data
-        :param channels: number of variables used in input data
-        :param regularizer: <not used here>
-        :param dropout_rate: dropout rate used in the model [0, 1)
-        :param window_lead_time: number of time steps to forecast in the output layer
+        :param shape_inputs: list of input shapes (expect len=1 with shape=(window_hist, station, variables))
+        :param shape_outputs: list of output shapes (expect len=1 with shape=(window_forecast,))
         """
-        super().__init__()
+        assert len(shape_inputs) == 1
+        super().__init__(shape_inputs[0], shape_outputs[0])
 
         # settings
-        self.window_history_size = window_history_size
-        self.window_lead_time = window_lead_time
-        self.channels = channels[0]
         self.dropout_rate = .3
         self.regularizer = keras.regularizers.l2(0.001)
         self.initial_lr = 1e-3
@@ -671,9 +630,7 @@ class MyPaperModel(AbstractModelClass):
 
         ##########################################
         inception_model = InceptionModelBase()
 
-        X_input = keras.layers.Input(
-            shape=(
-                self.window_history_size + 1, 1, self.channels))  # add 1 to window_size to include current time step t0
+        X_input = keras.layers.Input(shape=self.shape_inputs)
 
         pad_size = PadUtils.get_padding_for_same(first_kernel)
         # X_in = adv_pad.SymmetricPadding2D(padding=pad_size)(X_input)
@@ -691,7 +648,7 @@ class MyPaperModel(AbstractModelClass):
                                               padding=self.padding)
         # out_minor1 = flatten_tail(X_in, 'minor_1', False, self.dropout_rate, self.window_lead_time,
         #                           self.activation, 32, 64)
-        out_minor1 = flatten_tail(X_in, inner_neurons=64, activation=activation, output_neurons=self.window_lead_time,
+        out_minor1 = flatten_tail(X_in, inner_neurons=64, activation=activation, output_neurons=self.shape_outputs,
                                   output_activation='linear', reduction_filter=32,
                                   name='minor_1', bound_weight=False, dropout_rate=self.dropout_rate,
                                   kernel_regularizer=self.regularizer
@@ -709,7 +666,7 @@ class MyPaperModel(AbstractModelClass):
         #                                    batch_normalisation=True)
         #############################################
 
-        out_main = flatten_tail(X_in, inner_neurons=64 * 2, activation=activation, output_neurons=self.window_lead_time,
+        out_main = flatten_tail(X_in, inner_neurons=64 * 2, activation=activation, output_neurons=self.shape_outputs,
                                 output_activation='linear', reduction_filter=64 * 2,
                                 name='Main', bound_weight=False, dropout_rate=self.dropout_rate,
                                 kernel_regularizer=self.regularizer
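
The model_setup.py change below produces exactly these constructor arguments. get_X() and get_Y() are expected to return lists of batched arrays, so stripping the batch axis with shape[1:] yields one per-sample shape per input/output branch. A sketch with assumed dimensions:

    import numpy as np

    X = [np.zeros((256, 7, 1, 9))]   # hypothetical (batch, window_hist + 1, station, variables)
    Y = [np.zeros((256, 4))]         # hypothetical (batch, window_forecast)

    shape_inputs = list(map(lambda x: x.shape[1:], X))    # [(7, 1, 9)]
    shape_outputs = list(map(lambda y: y.shape[1:], Y))   # [(4,)]
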
diff --git a/src/run_modules/model_setup.py b/src/run_modules/model_setup.py
index dc537eb1cd3e5cf04fbdddee3017e4ace7f7bfca..5acdac0193a2b4f870ba8de351d6b9c8a24af1b0 100644
--- a/src/run_modules/model_setup.py
+++ b/src/run_modules/model_setup.py
@@ -89,9 +89,11 @@ class ModelSetup(RunEnvironment):
         self.compile_model()
 
     def _set_channels(self):
-        """Set channels as number of variables of train generator."""
-        channels = list(map(lambda x: x[0].shape[-1], self.data_store.get("data_collection", "train")[0].get_X()))
-        self.data_store.set("channels", channels, self.scope)
+        """Set input and output shapes from train collection."""
+        shape = list(map(lambda x: x.shape[1:], self.data_store.get("data_collection", "train")[0].get_X()))
+        self.data_store.set("shape_inputs", shape, self.scope)
+        shape = list(map(lambda y: y.shape[1:], self.data_store.get("data_collection", "train")[0].get_Y()))
+        self.data_store.set("shape_outputs", shape, self.scope)
 
     def compile_model(self):
         """
@@ -128,8 +130,8 @@ class ModelSetup(RunEnvironment):
             logging.info('no weights to reload...')
 
     def build_model(self):
-        """Build model using window_history_size, window_lead_time and channels from data store."""
-        args_list = ["window_history_size", "window_lead_time", "channels"]
+        """Build model using input and output shapes from data store."""
+        args_list = ["shape_inputs", "shape_outputs"]
         args = self.data_store.create_args_dict(args_list, self.scope)
         model = self.data_store.get("model_class")
         self.model = model(**args)
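
Taken together, build_model now resolves to a call like the following (shape values are hypothetical):

    model = MyLittleModel(shape_inputs=[(7, 1, 9)], shape_outputs=[(4,)])
    assert model.shape_inputs == (7, 1, 9)
    assert model.shape_outputs == 4   # the 1-tuple was collapsed by __extract_from_tuple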