Skip to content
Snippets Groups Projects
Commit 4608b470 authored by lukas leufen's avatar lukas leufen
Browse files

refactored ordinary least squared model into class

parent 057b47fe
No related branches found
No related tags found
2 merge requests!24include recent development,!23Lukas issue018 feat evaluate train val
......@@ -3,8 +3,44 @@ __date__ = '2019-12-11'
import statsmodels.api as sm
import numpy as np
def ordinary_least_squared_model(x, y):
ols_model = sm.OLS(y, x)
return ols_model.fit()
class OrdinaryLeastSquaredModel:
def __init__(self, generator):
self.x = []
self.y = []
self.generator = generator
self.model = self.train_ols_model_from_generator()
def train_ols_model_from_generator(self):
self.set_x_y_from_generator()
self.x = sm.add_constant(self.x)
return self.ordinary_least_squared_model(self.x, self.y)
def set_x_y_from_generator(self):
data_x = None
data_y = None
for item in self.generator:
x = self.reshape_xarray_to_numpy(item[0])
y = item[1].values
data_x = np.concatenate((data_x, x), axis=0) if data_x is not None else x
data_y = np.concatenate((data_y, y), axis=0) if data_y is not None else y
self.x = data_x
self.y = data_y
def predict(self, data):
data = sm.add_constant(self.reshape_xarray_to_numpy(data))
return self.model.predict(data)
@staticmethod
def reshape_xarray_to_numpy(data):
shape = data.values.shape
res = data.values.reshape(shape[0], shape[1] * shape[3])
return res
@staticmethod
def ordinary_least_squared_model(x, y):
ols_model = sm.OLS(y, x)
return ols_model.fit()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment