diff --git a/src/model_modules/linear_model.py b/src/model_modules/linear_model.py index 9d7ba7e1a65ce49e10da9e4d9f0131cafc35b3d1..17a9b2326ab6ba1829ee4f65f0161de887e70778 100644 --- a/src/model_modules/linear_model.py +++ b/src/model_modules/linear_model.py @@ -3,8 +3,44 @@ __date__ = '2019-12-11' import statsmodels.api as sm +import numpy as np -def ordinary_least_squared_model(x, y): - ols_model = sm.OLS(y, x) - return ols_model.fit() +class OrdinaryLeastSquaredModel: + + def __init__(self, generator): + self.x = [] + self.y = [] + self.generator = generator + self.model = self.train_ols_model_from_generator() + + def train_ols_model_from_generator(self): + self.set_x_y_from_generator() + self.x = sm.add_constant(self.x) + return self.ordinary_least_squared_model(self.x, self.y) + + def set_x_y_from_generator(self): + data_x = None + data_y = None + for item in self.generator: + x = self.reshape_xarray_to_numpy(item[0]) + y = item[1].values + data_x = np.concatenate((data_x, x), axis=0) if data_x is not None else x + data_y = np.concatenate((data_y, y), axis=0) if data_y is not None else y + self.x = data_x + self.y = data_y + + def predict(self, data): + data = sm.add_constant(self.reshape_xarray_to_numpy(data)) + return self.model.predict(data) + + @staticmethod + def reshape_xarray_to_numpy(data): + shape = data.values.shape + res = data.values.reshape(shape[0], shape[1] * shape[3]) + return res + + @staticmethod + def ordinary_least_squared_model(x, y): + ols_model = sm.OLS(y, x) + return ols_model.fit()