diff --git a/src/model_modules/linear_model.py b/src/model_modules/linear_model.py index 3d5323e1b0303b497c1f26c4e84ee9b968380425..933a108c1b06e1786f75e7f4ebd9b220fbe812dd 100644 --- a/src/model_modules/linear_model.py +++ b/src/model_modules/linear_model.py @@ -31,7 +31,7 @@ class OrdinaryLeastSquaredModel: self.y = data_y def predict(self, data): - data = sm.add_constant(self.reshape_xarray_to_numpy(data)) + data = sm.add_constant(self.reshape_xarray_to_numpy(data), has_constant="add") return np.atleast_2d(self.model.predict(data)) @staticmethod diff --git a/src/run_modules/model_setup.py b/src/run_modules/model_setup.py index d2c8d93fc957ecb2990e99000cbd3588e2d83eef..32ca0d2e82af32d8164d80ac42731e10f431a458 100644 --- a/src/run_modules/model_setup.py +++ b/src/run_modules/model_setup.py @@ -100,5 +100,5 @@ class ModelSetup(RunEnvironment): def plot_model(self): # pragma: no cover with tf.device("/cpu:0"): - file_name = f"{self.model_name.split(sep='.')[0]}.pdf" + file_name = f"{self.model_name.rsplit('.', 1)[0]}.pdf" keras.utils.plot_model(self.model, to_file=file_name, show_shapes=True, show_layer_names=True) diff --git a/src/run_modules/pre_processing.py b/src/run_modules/pre_processing.py index 3263f5c4562eeac321c7ce621df551fdf6373ba0..1d014c9e6f4fc0a9168c4d3d31b1141c39fff2a1 100644 --- a/src/run_modules/pre_processing.py +++ b/src/run_modules/pre_processing.py @@ -137,9 +137,10 @@ class PreProcessing(RunEnvironment): for station in all_stations: t_inner.run() try: - # (history, label) = data_gen[station] data = data_gen.get_data_generator(key=station, load_local_tmp_storage=load_tmp, save_local_tmp_storage=save_tmp) + if data.history is None: + raise AttributeError valid_stations.append(station) logging.debug(f'{station}: history_shape = {data.history.transpose("datetime", "window", "Stations", "variables").shape}') logging.debug(f"{station}: loading time = {t_inner}") diff --git a/test/test_model_modules/test_linear_model.py b/test/test_model_modules/test_linear_model.py new file mode 100644 index 0000000000000000000000000000000000000000..e4e10e9db04ba041d61d6ebcf5de3a23380c8ebe --- /dev/null +++ b/test/test_model_modules/test_linear_model.py @@ -0,0 +1,8 @@ + +from src.model_modules.linear_model import OrdinaryLeastSquaredModel + + +class TestOrdinaryLeastSquareModel: + + def test_constant_input_variable(self): + pass