From 4608b470df53c7efc1888d1bce5132142d2346ea Mon Sep 17 00:00:00 2001
From: lukas leufen <l.leufen@fz-juelich.de>
Date: Thu, 12 Dec 2019 14:40:42 +0100
Subject: [PATCH] refactored ordinary least squared model into class

---
 src/model_modules/linear_model.py | 42 ++++++++++++++++++++++++++++---
 1 file changed, 39 insertions(+), 3 deletions(-)

diff --git a/src/model_modules/linear_model.py b/src/model_modules/linear_model.py
index 9d7ba7e1..17a9b232 100644
--- a/src/model_modules/linear_model.py
+++ b/src/model_modules/linear_model.py
@@ -3,8 +3,44 @@ __date__ = '2019-12-11'
 
 
 import statsmodels.api as sm
+import numpy as np
 
 
-def ordinary_least_squared_model(x, y):
-    ols_model = sm.OLS(y, x)
-    return ols_model.fit()
+class OrdinaryLeastSquaredModel:
+
+    def __init__(self, generator):
+        self.x = []
+        self.y = []
+        self.generator = generator
+        self.model = self.train_ols_model_from_generator()
+
+    def train_ols_model_from_generator(self):
+        self.set_x_y_from_generator()
+        self.x = sm.add_constant(self.x)
+        return self.ordinary_least_squared_model(self.x, self.y)
+
+    def set_x_y_from_generator(self):
+        data_x = None
+        data_y = None
+        for item in self.generator:
+            x = self.reshape_xarray_to_numpy(item[0])
+            y = item[1].values
+            data_x = np.concatenate((data_x, x), axis=0) if data_x is not None else x
+            data_y = np.concatenate((data_y, y), axis=0) if data_y is not None else y
+        self.x = data_x
+        self.y = data_y
+
+    def predict(self, data):
+        data = sm.add_constant(self.reshape_xarray_to_numpy(data))
+        return self.model.predict(data)
+
+    @staticmethod
+    def reshape_xarray_to_numpy(data):
+        shape = data.values.shape
+        res = data.values.reshape(shape[0], shape[1] * shape[3])
+        return res
+
+    @staticmethod
+    def ordinary_least_squared_model(x, y):
+        ols_model = sm.OLS(y, x)
+        return ols_model.fit()
-- 
GitLab