__author__ = "Felix Kleinert, Lukas Leufen"
__date__ = '2019-12-11'


import numpy as np
import statsmodels.api as sm


class OrdinaryLeastSquaredModel:
    """Ordinary least squares regression fitted on all batches of a data generator.

    The model is fitted eagerly on construction: every ``(x, y)`` item yielded by
    ``generator`` is flattened, stacked along the first axis, augmented with an
    intercept column and passed to ``statsmodels.api.OLS``.
    """

    def __init__(self, generator):
        """Collect all data from ``generator`` and fit the OLS model.

        :param generator: iterable yielding indexable ``(x, y)`` pairs whose items both
            expose a ``.values`` array (presumably ``xarray.DataArray`` — TODO confirm).
            It is iterated once, during construction.
        """
        self.x = []
        self.y = []
        self.generator = generator
        self.model = self.train_ols_model_from_generator()

    def train_ols_model_from_generator(self):
        """Load the training data, add an intercept column and fit the model.

        :return: the fitted statsmodels regression results object.
        :raises ValueError: if the generator yields no batches.
        """
        self.set_x_y_from_generator()
        # Always add the intercept column, exactly as ``predict`` does. With the default
        # ``has_constant="skip"`` statsmodels silently omits it when the features already
        # contain a constant column, leaving the fitted model with one parameter fewer
        # than the design matrix built in ``predict``.
        self.x = sm.add_constant(self.x, has_constant="add")
        return self.ordinary_least_squared_model(self.x, self.y)

    def set_x_y_from_generator(self):
        """Stack all batches of the generator into ``self.x`` and ``self.y``.

        Inputs are flattened with ``reshape_xarray_to_numpy``; targets are taken from
        ``.values`` as-is. Both are concatenated along axis 0.

        :raises ValueError: if the generator yields no batches.
        """
        x_batches = []
        y_batches = []
        for item in self.generator:
            x_batches.append(self.reshape_xarray_to_numpy(item[0]))
            y_batches.append(item[1].values)
        if not x_batches:
            raise ValueError("generator did not yield any (x, y) batches; cannot fit an OLS model")
        # Concatenate once at the end instead of per batch, which would re-copy all
        # previously collected rows on every iteration (quadratic in the number of batches).
        self.x = np.concatenate(x_batches, axis=0)
        self.y = np.concatenate(y_batches, axis=0)

    def predict(self, data):
        """Predict targets for ``data`` with the fitted model.

        :param data: object exposing a ``.values`` array shaped like the training inputs.
        :return: predictions as an (at least) 2-D numpy array.
        """
        data = sm.add_constant(self.reshape_xarray_to_numpy(data), has_constant="add")
        return np.atleast_2d(self.model.predict(data))

    @staticmethod
    def reshape_xarray_to_numpy(data):
        """Flatten ``data.values`` into a 2-D ``(samples, features)`` matrix.

        Axes 1 and 3 are merged into the feature axis. This presumes a 4-D array whose
        axis 2 has length 1 — otherwise the element count does not match and ``reshape``
        raises.
        """
        shape = data.values.shape
        res = data.values.reshape(shape[0], shape[1] * shape[3])
        return res

    @staticmethod
    def ordinary_least_squared_model(x, y):
        """Fit ``y`` on the design matrix ``x`` with statsmodels OLS and return the results."""
        ols_model = sm.OLS(y, x)
        return ols_model.fit()