diff --git a/.gitignore b/.gitignore index 31cfe991f3cb38160cc5a2b330ca4725ebaa834f..11b7c159fdc450144f4103db66af1eabe42791a0 100644 --- a/.gitignore +++ b/.gitignore @@ -55,5 +55,6 @@ Thumbs.db htmlcov/ .pytest_cache /test/data/ +/test/test_modules/data/ report.html /TestExperiment/ diff --git a/run.py b/run.py index 1579ae35f0d270dc0d2529cf1b6d36bc410e317a..ea8c04ebde02a80b899a356eb0f7794055abe2d6 100644 --- a/run.py +++ b/run.py @@ -4,20 +4,26 @@ __date__ = '2019-11-14' import logging import argparse + from src.modules.experiment_setup import ExperimentSetup -from src.modules import run, PreProcessing, Training, PostProcessing +from src.modules.run_environment import RunEnvironment +from src.modules.pre_processing import PreProcessing +from src.modules.model_setup import ModelSetup +from src.modules.modules import Training, PostProcessing -def main(): +def main(parser_args): - with run(): - exp_setup = ExperimentSetup(args, trainable=True, stations=['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087']) + with RunEnvironment(): + ExperimentSetup(parser_args, stations=['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087', 'DEBW001'], + station_type='background') + PreProcessing() - PreProcessing(exp_setup) + ModelSetup() - Training(exp_setup) + Training() - PostProcessing(exp_setup) + PostProcessing() if __name__ == "__main__": @@ -30,6 +36,4 @@ if __name__ == "__main__": help="set experiment date as string") args = parser.parse_args() - experiment = ExperimentSetup(args, stations=['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087']) - a = 1 - # main() + main(args) diff --git a/src/data_generator.py b/src/data_generator.py index 860791235f111a7ffb151f2b06424be76dc8eba7..4e7dda9363c226c5fa92d03f1dbae6470e48d496 100644 --- a/src/data_generator.py +++ b/src/data_generator.py @@ -20,7 +20,7 @@ class DataGenerator(keras.utils.Sequence): def __init__(self, data_path: str, network: str, stations: Union[str, List[str]], variables: List[str], interpolate_dim: str, target_dim: str, target_var: str, station_type: str = None, - interpolate_method: str = "linear", limit_nan_fill: int = 1, window_history: int = 7, + interpolate_method: str = "linear", limit_nan_fill: int = 1, window_history_size: int = 7, window_lead_time: int = 4, transform_method: str = "standardise", **kwargs): self.data_path = os.path.abspath(data_path) self.network = network @@ -32,7 +32,7 @@ class DataGenerator(keras.utils.Sequence): self.station_type = station_type self.interpolate_method = interpolate_method self.limit_nan_fill = limit_nan_fill - self.window_history = window_history + self.window_history_size = window_history_size self.window_lead_time = window_lead_time self.transform_method = transform_method self.kwargs = kwargs @@ -100,7 +100,7 @@ class DataGenerator(keras.utils.Sequence): **self.kwargs) data.interpolate(self.interpolate_dim, method=self.interpolate_method, limit=self.limit_nan_fill) data.transform("datetime", method=self.transform_method) - data.make_history_window(self.interpolate_dim, self.window_history) + data.make_history_window(self.interpolate_dim, self.window_history_size) data.make_labels(self.target_dim, self.target_var, self.interpolate_dim, self.window_lead_time) data.history_label_nan_remove(self.interpolate_dim) return data diff --git a/src/datastore.py b/src/datastore.py index bb8474a04b503b1ff50fcea1b7e5f8bbd1d9ebea..5f0df67573dd510fdc4f04d0cc632b36c5082959 100644 --- a/src/datastore.py +++ b/src/datastore.py @@ -2,7 +2,7 @@ __author__ = 'Lukas Leufen' __date__ = '2019-11-22' -from typing import Any, List, Tuple +from typing import Any, List, Tuple, Dict from abc import ABC @@ -36,7 +36,7 @@ class AbstractDataStore(ABC): """ def __init__(self): # empty initialise the data-store variables - self._store = {} + self._store: Dict = {} def put(self, name: str, obj: Any, scope: str) -> None: """ @@ -89,6 +89,15 @@ class AbstractDataStore(ABC): def clear_data_store(self) -> None: self._store = {} + def create_args_dict(self, arg_list: List[str], scope: str = "general"): + args = {} + for arg in arg_list: + try: + args[arg] = self.get(arg, scope) + except (NameNotFoundInDataStore, NameNotFoundInScope): + pass + return args + class DataStoreByVariable(AbstractDataStore): diff --git a/src/flatten.py b/src/flatten.py new file mode 100644 index 0000000000000000000000000000000000000000..1166cf328ee2c8326b0628c4a93184a2dece16fe --- /dev/null +++ b/src/flatten.py @@ -0,0 +1,32 @@ +__author__ = "Lukas Leufen" +__date__ = '2019-12-02' + +import keras +from typing import Callable + + +def flatten_tail(input_X: keras.layers, name: str, bound_weight: bool = False, dropout_rate: float = 0.0, + window_lead_time: int = 4, activation: Callable = keras.activations.relu, + reduction_filter: int = 64, first_dense: int = 64): + + X_in = keras.layers.Conv2D(reduction_filter, (1, 1), padding='same', name='{}_Conv_1x1'.format(name))(input_X) + + X_in = activation(name='{}_conv_act'.format(name))(X_in) + + X_in = keras.layers.Flatten(name='{}'.format(name))(X_in) + + X_in = keras.layers.Dropout(dropout_rate, name='{}_Dropout_1'.format(name))(X_in) + X_in = keras.layers.Dense(first_dense, kernel_regularizer=keras.regularizers.l2(0.01), + name='{}_Dense_1'.format(name))(X_in) + if bound_weight: + X_in = keras.layers.Activation('tanh')(X_in) + else: + try: + X_in = activation(name='{}_act'.format(name))(X_in) + except: + X_in = activation()(X_in) + + X_in = keras.layers.Dropout(dropout_rate, name='{}_Dropout_2'.format(name))(X_in) + out = keras.layers.Dense(window_lead_time, activation='linear', kernel_regularizer=keras.regularizers.l2(0.01), + name='{}_Dense_2'.format(name))(X_in) + return out diff --git a/src/inception_model.py b/src/inception_model.py index 64b5e09674dd9fccd8282b80c4fb9e87bbd5ea5c..126dc15320f4b0ac7ff660b709601f083700d01f 100644 --- a/src/inception_model.py +++ b/src/inception_model.py @@ -3,6 +3,7 @@ __date__ = '2019-10-22' import keras import keras.layers as layers +import logging class InceptionModelBase: @@ -51,7 +52,7 @@ class InceptionModelBase: regularizer = kwargs.get('regularizer', keras.regularizers.l2(0.01)) bn_settings = kwargs.get('bn_settings', {}) act_settings = kwargs.get('act_settings', {}) - print(f'Inception Block with activation: {activation}') + logging.debug(f'Inception Block with activation: {activation}') block_name = f'Block_{self.number_of_blocks}{self.block_part_name()}_{tower_kernel[0]}x{tower_kernel[1]}' diff --git a/src/modules/experiment_setup.py b/src/modules/experiment_setup.py index a20f0b83e9828550d2f717502b5371c2c1ad7e9a..472173c7df371bcb139672d02bfa6723dbd0658e 100644 --- a/src/modules/experiment_setup.py +++ b/src/modules/experiment_setup.py @@ -28,7 +28,7 @@ class ExperimentSetup(RunEnvironment): """ def __init__(self, parser_args=None, var_all_dict=None, stations=None, network=None, station_type=None, variables=None, - statistics_per_var=None, start=None, end=None, window_history=None, target_var="o3", target_dim=None, + statistics_per_var=None, start=None, end=None, window_history_size=None, target_var="o3", target_dim=None, window_lead_time=None, dimensions=None, interpolate_dim=None, interpolate_method=None, limit_nan_fill=None, train_start=None, train_end=None, val_start=None, val_end=None, test_start=None, test_end=None, use_all_stations_on_all_data_sets=True, trainable=False, fraction_of_train=None, @@ -58,7 +58,7 @@ class ExperimentSetup(RunEnvironment): self._set_param("statistics_per_var", statistics_per_var, default=self.data_store.get("var_all_dict", "general")) self._set_param("start", start, default="1997-01-01", scope="general") self._set_param("end", end, default="2017-12-31", scope="general") - self._set_param("window_history", window_history, default=13) + self._set_param("window_history_size", window_history_size, default=13) # target self._set_param("target_var", target_var, default="o3") diff --git a/src/modules/model_setup.py b/src/modules/model_setup.py new file mode 100644 index 0000000000000000000000000000000000000000..925d83d733b8cb5715b515856e09e773737778a9 --- /dev/null +++ b/src/modules/model_setup.py @@ -0,0 +1,176 @@ +__author__ = "Lukas Leufen" +__date__ = '2019-12-02' + + +import keras +from keras import losses, layers +from keras.callbacks import ModelCheckpoint +from keras.regularizers import l2 +from keras.optimizers import Adam, SGD +import tensorflow as tf +import logging +import os + +from src.modules.run_environment import RunEnvironment +from src.helpers import l_p_loss, LearningRateDecay +from src.inception_model import InceptionModelBase +from src.flatten import flatten_tail + + +class ModelSetup(RunEnvironment): + + def __init__(self): + + # create run framework + super().__init__() + self.model = None + self.model_name = self.data_store.get("experiment_name", "general") + "model-best.h5" + self.scope = "general.model" + self._run() + + def _run(self): + + # create checkpoint + self._set_checkpoint() + + # set all model settings + self.my_model_settings() + + # build model graph using settings from my_model_settings() + self.build_model() + + # plot model structure + self.plot_model() + + # load weights if no training shall be performed + if self.data_store.get("trainable", self.scope) is False: + self.load_weights() + + # compile model + self.compile_model() + + def compile_model(self): + optimizer = self.data_store.get("optimizer", self.scope) + loss = self.data_store.get("loss", self.scope) + self.model.compile(optimizer=optimizer, loss=loss, metrics=["mse", "mae"]) + self.data_store.put("model", self.model, self.scope) + + def _set_checkpoint(self): + checkpoint = ModelCheckpoint(self.model_name, verbose=1, monitor='val_loss', save_best_only=True, mode='auto') + self.data_store.put("checkpoint", checkpoint, self.scope) + + def load_weights(self): + try: + logging.debug('reload weights...') + self.model.load_weights(self.model_name) + except OSError: + logging.debug('no weights to reload...') + + def build_model(self): + args_list = ["activation", "window_history_size", "channels", "regularizer", "dropout_rate", "window_lead_time"] + args = self.data_store.create_args_dict(args_list, self.scope) + self.model = my_model(**args) + + def plot_model(self): # pragma: no cover + with tf.device("/cpu:0"): + path = self.data_store.get("experiment_path", "general") + name = self.data_store.get("experiment_name", "general") + "model.pdf" + file_name = os.path.join(path, name) + keras.utils.plot_model(self.model, to_file=file_name, show_shapes=True, show_layer_names=True) + + def my_model_settings(self): + + # channels + X, _ = self.data_store.get("generator", "general.train")[0] + channels = X.shape[-1] # input variables + self.data_store.put("channels", channels, self.scope) + + # dropout + self.data_store.put("dropout_rate", 0.1, self.scope) + + # regularizer + self.data_store.put("regularizer", l2(0.1), self.scope) + + # learning rate + initial_lr = 1e-2 + self.data_store.put("initial_lr", initial_lr, self.scope) + optimizer = SGD(lr=initial_lr, momentum=0.9) + # optimizer=Adam(lr=initial_lr, amsgrad=True) + self.data_store.put("optimizer", optimizer, self.scope) + self.data_store.put("lr_decay", LearningRateDecay(base_lr=initial_lr, drop=.94, epochs_drop=10), self.scope) + + # learning settings + self.data_store.put("epochs", 2, self.scope) + self.data_store.put("batch_size", int(256), self.scope) + + # activation + activation = layers.PReLU # ELU #LeakyReLU keras.activations.tanh # + self.data_store.put("activation", activation, self.scope) + + # set los + loss_all = my_loss() + self.data_store.put("loss", loss_all, self.scope) + + +def my_loss(): + loss = l_p_loss(4) + keras_loss = losses.mean_squared_error + loss_all = [loss] + [keras_loss] + return loss_all + + +def my_model(activation, window_history_size, channels, regularizer, dropout_rate, window_lead_time): + + conv_settings_dict1 = { + 'tower_1': {'reduction_filter': 8, 'tower_filter': 8 * 2, 'tower_kernel': (3, 1), 'activation': activation}, + 'tower_2': {'reduction_filter': 8, 'tower_filter': 8 * 2, 'tower_kernel': (5, 1), 'activation': activation}, + 'tower_3': {'reduction_filter': 8, 'tower_filter': 8 * 2, 'tower_kernel': (1, 1), 'activation': activation}, + } + + pool_settings_dict1 = {'pool_kernel': (3, 1), 'tower_filter': 8 * 2, 'activation': activation} + + conv_settings_dict2 = {'tower_1': {'reduction_filter': 8 * 2, 'tower_filter': 16 * 2 * 2, 'tower_kernel': (3, 1), + 'activation': activation}, + 'tower_2': {'reduction_filter': 8 * 2, 'tower_filter': 16 * 2 * 2, 'tower_kernel': (5, 1), + 'activation': activation}, + 'tower_3': {'reduction_filter': 8 * 2, 'tower_filter': 16 * 2 * 2, 'tower_kernel': (1, 1), + 'activation': activation}, + } + pool_settings_dict2 = {'pool_kernel': (3, 1), 'tower_filter': 16, 'activation': activation} + + conv_settings_dict3 = {'tower_1': {'reduction_filter': 16 * 4, 'tower_filter': 32 * 2, 'tower_kernel': (3, 1), + 'activation': activation}, + 'tower_2': {'reduction_filter': 16 * 4, 'tower_filter': 32 * 2, 'tower_kernel': (5, 1), + 'activation': activation}, + 'tower_3': {'reduction_filter': 16 * 4, 'tower_filter': 32 * 2, 'tower_kernel': (1, 1), + 'activation': activation}, + } + + pool_settings_dict3 = {'pool_kernel': (3, 1), 'tower_filter': 32, 'activation': activation} + + ########################################## + inception_model = InceptionModelBase() + + X_input = layers.Input(shape=(window_history_size + 1, 1, channels)) # add 1 to window_size to include current time step t0 + + X_in = inception_model.inception_block(X_input, conv_settings_dict1, pool_settings_dict1, regularizer=regularizer, + batch_normalisation=True) + + out_minor = flatten_tail(X_in, 'Minor_1', bound_weight=True, activation=activation, dropout_rate=dropout_rate, + reduction_filter=4, first_dense=32, window_lead_time=window_lead_time) + + X_in = layers.Dropout(dropout_rate)(X_in) + + X_in = inception_model.inception_block(X_in, conv_settings_dict2, pool_settings_dict2, regularizer=regularizer, + batch_normalisation=True) + + X_in = layers.Dropout(dropout_rate)(X_in) + + X_in = inception_model.inception_block(X_in, conv_settings_dict3, pool_settings_dict3, regularizer=regularizer, + batch_normalisation=True) + ############################################# + + out_main = flatten_tail(X_in, 'Main', activation=activation, bound_weight=True, dropout_rate=dropout_rate, + reduction_filter=64, first_dense=64, window_lead_time=window_lead_time) + + return keras.Model(inputs=X_input, outputs=[out_minor, out_main]) diff --git a/src/modules/modules.py b/src/modules/modules.py index 033fd0779d8d140e684103b27fc7c025dedcdb81..888c7e06f0ef34b17f6c3f2fc2da6fe0316282f4 100644 --- a/src/modules/modules.py +++ b/src/modules/modules.py @@ -8,16 +8,14 @@ from src.modules.pre_processing import PreProcessing class Training(RunEnvironment): - def __init__(self, setup): + def __init__(self): super().__init__() - self.setup = setup class PostProcessing(RunEnvironment): - def __init__(self, setup): + def __init__(self): super().__init__() - self.setup = setup if __name__ == "__main__": diff --git a/src/modules/pre_processing.py b/src/modules/pre_processing.py index d3056f52bd0a60e0c9e7ed97fa593f3b596898a4..764613ea4558bfdef4cbada668f16608f81d5f95 100644 --- a/src/modules/pre_processing.py +++ b/src/modules/pre_processing.py @@ -1,3 +1,7 @@ +__author__ = "Lukas Leufen" +__date__ = '2019-11-25' + + import logging from typing import Any, Tuple, Dict, List @@ -9,7 +13,7 @@ from src.join import EmptyQueryResult DEFAULT_ARGS_LIST = ["data_path", "network", "stations", "variables", "interpolate_dim", "target_dim", "target_var"] -DEFAULT_KWARGS_LIST = ["limit_nan_fill", "window_history", "window_lead_time", "statistics_per_var", "station_type"] +DEFAULT_KWARGS_LIST = ["limit_nan_fill", "window_history_size", "window_lead_time", "statistics_per_var", "station_type"] class PreProcessing(RunEnvironment): @@ -29,22 +33,26 @@ class PreProcessing(RunEnvironment): # self._run() - def _create_args_dict(self, arg_list, scope="general"): - args = {} - for arg in arg_list: - try: - args[arg] = self.data_store.get(arg, scope) - except (NameNotFoundInDataStore, NameNotFoundInScope): - pass - return args - def _run(self): - args = self._create_args_dict(DEFAULT_ARGS_LIST) - kwargs = self._create_args_dict(DEFAULT_KWARGS_LIST) + args = self.data_store.create_args_dict(DEFAULT_ARGS_LIST) + kwargs = self.data_store.create_args_dict(DEFAULT_KWARGS_LIST) valid_stations = self.check_valid_stations(args, kwargs, self.data_store.get("stations", "general")) self.data_store.put("stations", valid_stations, "general") self.split_train_val_test() + def report_pre_processing(self): + logging.info(20 * '##') + n_train = len(self.data_store.get('generator', 'general.train')) + n_val = len(self.data_store.get('generator', 'general.val')) + n_test = len(self.data_store.get('generator', 'general.test')) + n_total = n_train + n_val + n_test + logging.info(f"Number of all stations: {n_total}") + logging.info(f"Number of training stations: {n_train}") + logging.info(f"Number of val stations: {n_val}") + logging.info(f"Number of test stations: {n_test}") + logging.info(f"TEST SHAPE OF GENERATOR CALL: {self.data_store.get('generator', 'general.test')[0][0].shape}" + f"{self.data_store.get('generator', 'general.test')[0][1].shape}") + def split_train_val_test(self): fraction_of_training = self.data_store.get("fraction_of_training", "general") stations = self.data_store.get("stations", "general") @@ -71,8 +79,8 @@ class PreProcessing(RunEnvironment): def create_set_split(self, index_list, set_name): scope = f"general.{set_name}" - args = self._create_args_dict(DEFAULT_ARGS_LIST, scope) - kwargs = self._create_args_dict(DEFAULT_KWARGS_LIST, scope) + args = self.data_store.create_args_dict(DEFAULT_ARGS_LIST, scope) + kwargs = self.data_store.create_args_dict(DEFAULT_KWARGS_LIST, scope) stations = args["stations"] if self.data_store.get("use_all_stations_on_all_data_sets", scope): set_stations = stations @@ -81,7 +89,7 @@ class PreProcessing(RunEnvironment): logging.debug(f"{set_name.capitalize()} stations (len={len(set_stations)}): {set_stations}") set_stations = self.check_valid_stations(args, kwargs, set_stations) self.data_store.put("stations", set_stations, scope) - set_args = self._create_args_dict(DEFAULT_ARGS_LIST, scope) + set_args = self.data_store.create_args_dict(DEFAULT_ARGS_LIST, scope) data_set = DataGenerator(**set_args, **kwargs) self.data_store.put("generator", data_set, scope) diff --git a/test/test_data_generator.py b/test/test_data_generator.py index 2fe8b8c0b5a7f4f8be9626b0061702acb53ecb6b..6801e064be804da0368520ab0ef81bdba8bca2d3 100644 --- a/test/test_data_generator.py +++ b/test/test_data_generator.py @@ -21,7 +21,7 @@ class TestDataGenerator: assert gen.target_var == 'o3' assert gen.interpolate_method == "linear" assert gen.limit_nan_fill == 1 - assert gen.window_history == 7 + assert gen.window_history_size == 7 assert gen.window_lead_time == 4 assert gen.transform_method == "standardise" assert gen.kwargs == {} @@ -44,7 +44,7 @@ class TestDataGenerator: assert station[0].Stations.data == "DEBW107" assert station[0].data.shape[1:] == (8, 1, 2) assert station[1].data.shape[-1] == gen.window_lead_time - assert station[0].data.shape[1] == gen.window_history + 1 + assert station[0].data.shape[1] == gen.window_history_size + 1 def test_iter(self, gen): assert hasattr(gen, '_iterator') is False diff --git a/test/test_modules.py b/test/test_modules.py deleted file mode 100644 index b28b04f643122b019e912540f228c8ed20be9eeb..0000000000000000000000000000000000000000 --- a/test/test_modules.py +++ /dev/null @@ -1,3 +0,0 @@ - - - diff --git a/test/test_modules/test_experiment_setup.py b/test/test_modules/test_experiment_setup.py index be3db59e2415f28ea63d42b7cc6ced6b2c095700..e1ec57f9e6cace563b6311ee0c1f34fefcb2c7c2 100644 --- a/test/test_modules/test_experiment_setup.py +++ b/test/test_modules/test_experiment_setup.py @@ -72,7 +72,7 @@ class TestExperimentSetup: assert data_store.get("statistics_per_var", "general") == default_var_all_dict assert data_store.get("start", "general") == "1997-01-01" assert data_store.get("end", "general") == "2017-12-31" - assert data_store.get("window_history", "general") == 13 + assert data_store.get("window_history_size", "general") == 13 # target assert data_store.get("target_var", "general") == "o3" assert data_store.get("target_dim", "general") == "variables" @@ -100,7 +100,7 @@ class TestExperimentSetup: var_all_dict={'o3': 'dma8eu', 'relhum': 'average_values', 'temp': 'maximum'}, stations=['DEBY053', 'DEBW059', 'DEBW027'], network="INTERNET", station_type="background", variables=["o3", "temp"], - statistics_per_var=None, start="1999-01-01", end="2001-01-01", window_history=4, + statistics_per_var=None, start="1999-01-01", end="2001-01-01", window_history_size=4, target_var="temp", target_dim="target", window_lead_time=10, dimensions="dim1", interpolate_dim="int_dim", interpolate_method="cubic", limit_nan_fill=5, train_start="2000-01-01", train_end="2000-01-02", val_start="2000-01-03", val_end="2000-01-04", test_start="2000-01-05", @@ -127,7 +127,7 @@ class TestExperimentSetup: 'temp': 'maximum'} assert data_store.get("start", "general") == "1999-01-01" assert data_store.get("end", "general") == "2001-01-01" - assert data_store.get("window_history", "general") == 4 + assert data_store.get("window_history_size", "general") == 4 # target assert data_store.get("target_var", "general") == "temp" assert data_store.get("target_dim", "general") == "target" diff --git a/test/test_modules/test_model_setup.py b/test/test_modules/test_model_setup.py new file mode 100644 index 0000000000000000000000000000000000000000..7f8c7a051542ff7fc317c0c92454c28f1d0d70b5 --- /dev/null +++ b/test/test_modules/test_model_setup.py @@ -0,0 +1,58 @@ +import logging +import pytest +import os +import keras + +from src.modules.model_setup import ModelSetup +from src.modules.run_environment import RunEnvironment +from src.data_generator import DataGenerator + + +class TestModelSetup: + + @pytest.fixture + def setup(self): + obj = object.__new__(ModelSetup) + super(ModelSetup, obj).__init__() + obj.scope = "general.modeltest" + obj.model = None + yield obj + RunEnvironment().__del__() + + @pytest.fixture + def gen(self): + return DataGenerator(os.path.join(os.path.dirname(__file__), 'data'), 'AIRBASE', 'DEBW107', ['o3', 'temp'], + 'datetime', 'variables', 'o3', statistics_per_var={'o3': 'dma8eu', 'temp': 'maximum'}) + + @pytest.fixture + def setup_with_gen(self, setup, gen): + setup.data_store.put("generator", gen, "general.train") + setup.data_store.put("window_history_size", gen.window_history_size, "general") + setup.data_store.put("window_lead_time", gen.window_lead_time, "general") + yield setup + RunEnvironment().__del__() + + def test_set_checkpoint(self, setup): + assert "general.modeltest" not in setup.data_store.search_name("checkpoint") + setup.model_name = "TestName" + setup._set_checkpoint() + assert "general.modeltest" in setup.data_store.search_name("checkpoint") + + def test_my_model_settings(self, setup_with_gen): + setup_with_gen.my_model_settings() + expected = {"channels", "dropout_rate", "regularizer", "initial_lr", "optimizer", "lr_decay", "epochs", + "batch_size", "activation", "loss"} + assert expected <= set(setup_with_gen.data_store.search_scope(setup_with_gen.scope, current_scope_only=True)) + + def test_build_model(self, setup_with_gen): + setup_with_gen.my_model_settings() + assert setup_with_gen.model is None + setup_with_gen.build_model() + assert isinstance(setup_with_gen.model, keras.Model) + + def test_load_weights(self): + pass + + def test_compile_model(self): + pass + diff --git a/test/test_modules/test_pre_processing.py b/test/test_modules/test_pre_processing.py index 1af910ee660510c5667e6170c82c079dcf515bb2..13abe62a2b9199ad8d92528ff5363bd54f1be221 100644 --- a/test/test_modules/test_pre_processing.py +++ b/test/test_modules/test_pre_processing.py @@ -87,8 +87,8 @@ class TestPreProcessing: def test_check_valid_stations(self, caplog, obj_with_exp_setup): pre = obj_with_exp_setup caplog.set_level(logging.INFO) - args = pre._create_args_dict(DEFAULT_ARGS_LIST) - kwargs = pre._create_args_dict(DEFAULT_KWARGS_LIST) + args = pre.data_store.create_args_dict(DEFAULT_ARGS_LIST) + kwargs = pre.data_store.create_args_dict(DEFAULT_KWARGS_LIST) stations = pre.data_store.get("stations", "general") valid_stations = pre.check_valid_stations(args, kwargs, stations) assert len(valid_stations) < len(stations) @@ -105,11 +105,11 @@ class TestPreProcessing: assert dummy_list[test] == list(range(13, 15)) def test_create_args_dict_default_scope(self, obj_super_init): - assert obj_super_init._create_args_dict(["NAME1", "NAME2"]) == {"NAME1": 1, "NAME2": 2} + assert obj_super_init.data_store.create_args_dict(["NAME1", "NAME2"]) == {"NAME1": 1, "NAME2": 2} def test_create_args_dict_given_scope(self, obj_super_init): - assert obj_super_init._create_args_dict(["NAME1", "NAME2"], scope="general.sub") == {"NAME1": 10, "NAME2": 2} + assert obj_super_init.data_store.create_args_dict(["NAME1", "NAME2"], scope="general.sub") == {"NAME1": 10, "NAME2": 2} def test_create_args_dict_missing_entry(self, obj_super_init): - assert obj_super_init._create_args_dict(["NAME5", "NAME2"]) == {"NAME2": 2} - assert obj_super_init._create_args_dict(["NAME4", "NAME2"]) == {"NAME2": 2} + assert obj_super_init.data_store.create_args_dict(["NAME5", "NAME2"]) == {"NAME2": 2} + assert obj_super_init.data_store.create_args_dict(["NAME4", "NAME2"]) == {"NAME2": 2}