Skip to content
Snippets Groups Projects
Commit ad302db1 authored by lukas leufen's avatar lukas leufen
Browse files

include all bugfixes

parents 811bd7d6 b146b12b
Branches
Tags
3 merge requests!90WIP: new release update,!89Resolve "release branch / CI on gpu",!61Resolve "REFAC: clean-up bootstrap workflow"
......@@ -31,7 +31,7 @@ class OrdinaryLeastSquaredModel:
self.y = data_y
def predict(self, data):
data = sm.add_constant(self.reshape_xarray_to_numpy(data))
data = sm.add_constant(self.reshape_xarray_to_numpy(data), has_constant="add")
return np.atleast_2d(self.model.predict(data))
@staticmethod
......
......@@ -100,5 +100,5 @@ class ModelSetup(RunEnvironment):
def plot_model(self): # pragma: no cover
with tf.device("/cpu:0"):
file_name = f"{self.model_name.split(sep='.')[0]}.pdf"
file_name = f"{self.model_name.rsplit('.', 1)[0]}.pdf"
keras.utils.plot_model(self.model, to_file=file_name, show_shapes=True, show_layer_names=True)
......@@ -137,9 +137,10 @@ class PreProcessing(RunEnvironment):
for station in all_stations:
t_inner.run()
try:
# (history, label) = data_gen[station]
data = data_gen.get_data_generator(key=station, load_local_tmp_storage=load_tmp,
save_local_tmp_storage=save_tmp)
if data.history is None:
raise AttributeError
valid_stations.append(station)
logging.debug(f'{station}: history_shape = {data.history.transpose("datetime", "window", "Stations", "variables").shape}')
logging.debug(f"{station}: loading time = {t_inner}")
......
from src.model_modules.linear_model import OrdinaryLeastSquaredModel
class TestOrdinaryLeastSquareModel:
def test_constant_input_variable(self):
pass
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment