diff --git a/mlair/helpers/testing.py b/mlair/helpers/testing.py index 08ac7cab21567166149d7c05f1fd6450760856a5..21658ea52f194863ad709ae7efbea96a81d29cd9 100644 --- a/mlair/helpers/testing.py +++ b/mlair/helpers/testing.py @@ -170,8 +170,6 @@ def check_nested_equality(obj1, obj2, precision=None, skip_args=None): message = f"{obj1}!={obj2}\n{obj1} and {obj2} do not match" assert obj1 == obj2 except AssertionError: - message = message.split("\n") - logging.info(message[0]) - logging.debug(message[1]) + logging.info(message) return False return True diff --git a/mlair/run_modules/pre_processing.py b/mlair/run_modules/pre_processing.py index 64d1bfa20a81c11b3aca79c74c057e06d0b510b8..fc1ae4b7ad63a51b623aacb3d846d33ca3a482e0 100644 --- a/mlair/run_modules/pre_processing.py +++ b/mlair/run_modules/pre_processing.py @@ -412,19 +412,23 @@ class PreProcessing(RunEnvironment): logging.info(f"load snapshot for preprocessing from {file}") with open(file, "rb") as f: snapshot = dill.load(f) + excluded_params = ["activation", "activation_output", "add_dense_layer", "batch_normalization", "batch_path", + "batch_size", "block_length", "bootstrap_method", "bootstrap_path", "bootstrap_type", + "competitor_path", "competitors", "create_new_bootstraps", "create_new_model", + "create_snapshot", "data_collection", "debug_mode", "dense_layer_configuration", + "do_uncertainty_estimate", "dropout", "dropout_rnn", "early_stopping_epochs", "epochs", + "evaluate_competitors", "evaluate_feature_importance", "experiment_name", "experiment_path", + "exponent_last_layer", "forecast_path", "fraction_of_training", "hostname", "hpc_hosts", + "kernel_regularizer", "kernel_size", "layer_configuration", "log_level_stream", + "logging_path", "login_nodes", "loss_type", "loss_weights", "max_number_multiprocessing", + "model_class", "model_display_name", "model_path", "n_boots", "n_hidden", "n_layer", + "neighbors", "plot_list", "plot_path", "regularizer", "restore_best_model_weights", + "snapshot_load_path", "snapshot_path", "stations", "tmp_path", "train_model", + "transformation", "use_multiprocessing", ] - excluded_params = ["batch_path", "batch_size", "block_length", "bootstrap_method", "bootstrap_path", - "bootstrap_type", "competitor_path", "competitors", "create_new_bootstraps", - "create_new_model", "create_snapshot", "data_collection", "debug_mode", - "do_uncertainty_estimate", "early_stopping_epochs", "epochs", "evaluate_competitors", - "evaluate_feature_importance", "experiment_name", "experiment_path", "forecast_path", - "fraction_of_training", "hostname", "hpc_hosts", "log_level_stream", "logging_path", - "login_nodes", "max_number_multiprocessing", "model_class", "model_path", "n_boots", - "neighbors", "plot_list", "plot_path", "restore_best_model_weights", "snapshot_load_path", - "snapshot_path", "stations", "tmp_path", "train_model", "transformation", - "use_multiprocessing", ] data_handler = self.data_store.get("data_handler") - excluded_params = list(set(excluded_params + data_handler.store_attributes())) + model_class = self.data_store.get("model_class") + excluded_params = list(set(excluded_params + data_handler.store_attributes() + model_class.requirements())) if check_nested_equality(self.data_store._store, snapshot._store, skip_args=excluded_params) is True: self.update_datastore(snapshot, excluded_params=remove_items(excluded_params, ["transformation",