Skip to content
Snippets Groups Projects
Commit ad302db1 authored by lukas leufen's avatar lukas leufen
Browse files

include all bugfixes

parents 811bd7d6 b146b12b
No related branches found
No related tags found
3 merge requests!90WIP: new release update,!89Resolve "release branch / CI on gpu",!61Resolve "REFAC: clean-up bootstrap workflow"
...@@ -31,7 +31,7 @@ class OrdinaryLeastSquaredModel: ...@@ -31,7 +31,7 @@ class OrdinaryLeastSquaredModel:
self.y = data_y self.y = data_y
def predict(self, data): def predict(self, data):
data = sm.add_constant(self.reshape_xarray_to_numpy(data)) data = sm.add_constant(self.reshape_xarray_to_numpy(data), has_constant="add")
return np.atleast_2d(self.model.predict(data)) return np.atleast_2d(self.model.predict(data))
@staticmethod @staticmethod
......
...@@ -100,5 +100,5 @@ class ModelSetup(RunEnvironment): ...@@ -100,5 +100,5 @@ class ModelSetup(RunEnvironment):
def plot_model(self): # pragma: no cover def plot_model(self): # pragma: no cover
with tf.device("/cpu:0"): with tf.device("/cpu:0"):
file_name = f"{self.model_name.split(sep='.')[0]}.pdf" file_name = f"{self.model_name.rsplit('.', 1)[0]}.pdf"
keras.utils.plot_model(self.model, to_file=file_name, show_shapes=True, show_layer_names=True) keras.utils.plot_model(self.model, to_file=file_name, show_shapes=True, show_layer_names=True)
...@@ -137,9 +137,10 @@ class PreProcessing(RunEnvironment): ...@@ -137,9 +137,10 @@ class PreProcessing(RunEnvironment):
for station in all_stations: for station in all_stations:
t_inner.run() t_inner.run()
try: try:
# (history, label) = data_gen[station]
data = data_gen.get_data_generator(key=station, load_local_tmp_storage=load_tmp, data = data_gen.get_data_generator(key=station, load_local_tmp_storage=load_tmp,
save_local_tmp_storage=save_tmp) save_local_tmp_storage=save_tmp)
if data.history is None:
raise AttributeError
valid_stations.append(station) valid_stations.append(station)
logging.debug(f'{station}: history_shape = {data.history.transpose("datetime", "window", "Stations", "variables").shape}') logging.debug(f'{station}: history_shape = {data.history.transpose("datetime", "window", "Stations", "variables").shape}')
logging.debug(f"{station}: loading time = {t_inner}") logging.debug(f"{station}: loading time = {t_inner}")
......
from src.model_modules.linear_model import OrdinaryLeastSquaredModel
class TestOrdinaryLeastSquareModel:
def test_constant_input_variable(self):
pass
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment