__author__ = "Felix Kleinert, Lukas Leufen"
__date__ = '2019-12-11'

import numpy as np
import statsmodels.api as sm


class OrdinaryLeastSquaredModel:
    """Ordinary least squares (OLS) regression fitted on all batches of a data generator.

    The model is fitted eagerly when the instance is created. Every input batch is flattened
    to a 2-D ``(samples, features)`` matrix, an intercept column is prepended, and the result
    is fitted with :class:`statsmodels.api.OLS`.

    Attributes (set during construction):
        generator: iterable of ``(inputs, targets)`` pairs consumed once during fitting.
        x: design matrix (including the intercept column) used for the fit.
        y: stacked targets used for the fit.
        model: fitted statsmodels results object returned by ``OLS.fit()``.
    """

    def __init__(self, generator):
        """Collect all data from ``generator`` and fit the OLS model immediately.

        :param generator: iterable yielding indexable pairs ``(inputs, targets)`` whose
            elements expose a ``.values`` array (presumably ``xarray.DataArray`` -- TODO
            confirm against callers). It is iterated exactly once.
        """
        self.x = []
        self.y = []
        self.generator = generator
        self.model = self.train_ols_model_from_generator()

    def train_ols_model_from_generator(self):
        """Gather the training data, prepend an intercept column and fit the model.

        :return: fitted statsmodels regression results.
        :raises ValueError: if the generator did not yield any data.
        """
        self.set_x_y_from_generator()
        if self.x is None:
            raise ValueError("Cannot fit OLS model: the generator did not yield any data.")
        # Always add the intercept column (has_constant="add"), exactly as predict() does.
        # With the default "skip", a feature column holding the same non-zero value in every
        # sample would suppress the intercept here but not in predict(), so the design
        # matrices would differ in width and prediction would fail.
        self.x = sm.add_constant(self.x, has_constant="add")
        return self.ordinary_least_squared_model(self.x, self.y)

    def set_x_y_from_generator(self):
        """Read every batch from the generator and stack inputs and targets sample-wise.

        Inputs are flattened with :meth:`reshape_xarray_to_numpy`; targets are taken from
        ``.values`` unchanged. Both are concatenated along axis 0 and stored in ``self.x``
        and ``self.y``. If the generator yields nothing, both are set to ``None``.
        """
        x_batches = []
        y_batches = []
        for item in self.generator:
            x_batches.append(self.reshape_xarray_to_numpy(item[0]))
            y_batches.append(item[1].values)
        # Concatenate once at the end; concatenating inside the loop re-copies all previously
        # collected data on every iteration (quadratic in the number of batches).
        self.x = np.concatenate(x_batches, axis=0) if x_batches else None
        self.y = np.concatenate(y_batches, axis=0) if y_batches else None

    def predict(self, data):
        """Predict targets for ``data`` with the fitted model.

        :param data: input batch in the same layout as the training inputs (anything
            :meth:`reshape_xarray_to_numpy` accepts).
        :return: predictions wrapped by ``np.atleast_2d`` (a 1-D result of shape ``(n,)``
            becomes ``(1, n)``).
        """
        # "add" forces the intercept column even when every column looks constant
        # (e.g. a single-sample batch), keeping the width equal to the training design.
        data = sm.add_constant(self.reshape_xarray_to_numpy(data), has_constant="add")
        return np.atleast_2d(self.model.predict(data))

    @staticmethod
    def reshape_xarray_to_numpy(data):
        """Flatten ``data.values`` to a 2-D ``(samples, features)`` array.

        The first axis is kept and all remaining axes are merged in C order. For 4-D input
        with a size-1 third axis (the only layout the former ``shape[1] * shape[3]`` formula
        accepted) the result is unchanged; other layouts are now flattened instead of raising.

        :param data: object exposing a NumPy-compatible ``.values`` array (presumably an
            ``xarray.DataArray``).
        :return: 2-D NumPy array of shape ``(shape[0], prod(shape[1:]))``.
        """
        values = data.values
        shape = values.shape
        # Compute the feature count explicitly instead of using -1 so that arrays with zero
        # samples (size 0) can still be reshaped unambiguously.
        n_features = int(np.prod(shape[1:]))
        return values.reshape(shape[0], n_features)

    @staticmethod
    def ordinary_least_squared_model(x, y):
        """Fit an OLS regression of ``y`` on the design matrix ``x``.

        :param x: design matrix; ``sm.OLS`` does not add an intercept itself, so it must
            already contain one if an intercept is wanted.
        :param y: targets.
        :return: fitted statsmodels results object.
        """
        ols_model = sm.OLS(y, x)
        return ols_model.fit()