Commit 9796abcc authored by lukas leufen's avatar lukas leufen 👻
Browse files

Merge branch 'lukas_issue389_feat_parallel-make_prediction-in-postprocessing' into 'develop'

Resolve "parallel make_prediction in postprocessing"

See merge request !427
parents 003dc3ea 2575bd85
Pipeline #101911 passed with stages
in 17 minutes and 3 seconds
......@@ -125,8 +125,9 @@ class DefaultDataHandler(AbstractDataHandler):
def get_data(self, upsampling=False, as_numpy=True):
self._load()
X = self.get_X(upsampling, as_numpy)
Y = self.get_Y(upsampling, as_numpy)
as_numpy_X, as_numpy_Y = as_numpy if isinstance(as_numpy, tuple) else (as_numpy, as_numpy)
X = self.get_X(upsampling, as_numpy_X)
Y = self.get_Y(upsampling, as_numpy_Y)
self._reset_data()
return X, Y
......
......@@ -144,8 +144,8 @@ class KerasIterator(keras.utils.Sequence):
mod_rank = self._get_model_rank()
for data in self._collection:
logging.debug(f"prepare batches for {str(data)}")
X = data.get_X(upsampling=self.upsampling)
Y = [data.get_Y(upsampling=self.upsampling)[0] for _ in range(mod_rank)]
X, _Y = data.get_data(upsampling=self.upsampling)
Y = [_Y[0] for _ in range(mod_rank)]
if self.upsampling:
X, Y = self._permute_data(X, Y)
if remaining is not None:
......
......@@ -10,7 +10,6 @@ import sys
import traceback
from typing import Dict, Tuple, Union, List, Callable
import tensorflow.keras as keras
import numpy as np
import pandas as pd
import xarray as xr
......@@ -695,6 +694,7 @@ class PostProcessing(RunEnvironment):
logging.info(f"start train_ols_model on train data")
self.ols_model = OrdinaryLeastSquaredModel(self.train_data)
@TimeTrackingWrapper
def make_prediction(self, subset):
"""
Create predictions for NN, OLS, and persistence and add true observation as reference.
......@@ -707,7 +707,7 @@ class PostProcessing(RunEnvironment):
logging.info(f"start make_prediction for {subset_type}")
time_dimension = self.data_store.get("time_dim")
window_dim = self.data_store.get("window_dim")
subset_type = subset.name
for i, data in enumerate(subset):
input_data = data.get_X()
target_data = data.get_Y(as_numpy=False)
......
......@@ -106,6 +106,9 @@ class DummyData:
Y2 = np.random.randint(21, 30, size=(self.number_of_samples, 5, 1)) # samples, window, variables
return [Y1, Y2]
def get_data(self, upsampling=False, as_numpy=True):
return self.get_X(upsampling, as_numpy), self.get_Y(upsampling, as_numpy)
class TestKerasIterator:
......
......@@ -150,3 +150,6 @@ class DummyData:
Y1 = np.random.randint(0, 10, size=(self.number_of_samples, 5)) # samples, window
Y2 = np.random.randint(21, 30, size=(self.number_of_samples, 3)) # samples, window
return [Y1, Y2]
def get_data(self, upsampling=False, as_numpy=True):
return self.get_X(upsampling, as_numpy), self.get_Y(upsampling, as_numpy)
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment