"""Calculate ordinary least squared model."""

# Module metadata.
__author__ = "Felix Kleinert, Lukas Leufen"
__date__ = '2019-12-11'

import numpy as np
import statsmodels.api as sm


class OrdinaryLeastSquaredModel:
    """
    Implementation of an ordinary least squares (OLS) model.

    Inputs and outputs are retrieved from a generator. The generator has to be iterable and each of its items has to
    be a tuple ``(x, y)`` of xarray-like objects, i.e. objects exposing their data as numpy array via ``.values``. The
    OLS is calculated on initialisation using the statsmodels package. Train your personal OLS using:

    .. code-block:: python

        # each item yielded by train_data should be a tuple (x, y)
        my_ols_model = OrdinaryLeastSquaredModel(train_data)

    After calculation, use your OLS model with

    .. code-block:: python

        # input_data needs to be structured like the inputs of the train data
        result_ols = my_ols_model.predict(input_data)

    :param generator: iterable object returning tuples containing inputs and outputs as xarrays
    :raises ValueError: if the generator does not return any data
    """

    def __init__(self, generator):
        """Set up OLS model and fit it on all data returned by the generator."""
        self.x = []
        self.y = []
        self.generator = generator
        self.model = self._train_ols_model_from_generator()

    def _train_ols_model_from_generator(self):
        """Collect the training data, add the intercept column and return the fitted OLS results."""
        self._set_x_y_from_generator()
        # Always add the intercept column (has_constant="add"), exactly as predict() does. With the default "skip",
        # statsmodels omits the column whenever an input feature happens to be constant in the training data, so the
        # fitted model would have one parameter less than the design matrix built in predict().
        self.x = sm.add_constant(self.x, has_constant="add")
        return self.ordinary_least_squared_model(self.x, self.y)

    def _set_x_y_from_generator(self):
        """
        Stack inputs and outputs of all generator items along the first axis and store them in self.x and self.y.

        :raises ValueError: if the generator does not return any item
        """
        inputs, outputs = [], []
        for item in self.generator:
            inputs.append(self.reshape_xarray_to_numpy(item[0]))
            outputs.append(item[1].values)
        if not inputs:
            raise ValueError("Cannot train OLS model: generator did not return any data.")
        # Concatenate once at the end instead of re-allocating the growing array in every iteration.
        self.x = np.concatenate(inputs, axis=0)
        self.y = np.concatenate(outputs, axis=0)

    def predict(self, data):
        """
        Apply OLS model on data.

        :param data: inputs structured like the inputs used for training (data is read from ``data.values``)
        :return: model prediction as numpy array with at least two dimensions
        """
        data = sm.add_constant(self.reshape_xarray_to_numpy(data), has_constant="add")
        return np.atleast_2d(self.model.predict(data))

    @staticmethod
    def reshape_xarray_to_numpy(data):
        """
        Reshape xarray data to numpy data and flatten.

        The first axis is kept and all remaining axes are flattened into a single feature axis.

        :param data: object providing its data as numpy array via ``.values``
        :return: 2D numpy array of shape (shape[0], product of all remaining dimensions)
        """
        values = data.values
        shape = values.shape
        # Explicit feature count (instead of -1) keeps the reshape well-defined for arrays without any sample.
        n_features = int(np.prod(shape[1:], dtype=int))
        return values.reshape(shape[0], n_features)

    @staticmethod
    def ordinary_least_squared_model(x, y):
        """
        Calculate ols model using statsmodels.

        :param x: design matrix (exogenous variables)
        :param y: target values (endogenous variable)
        :return: fitted statsmodels results object
        """
        ols_model = sm.OLS(y, x)
        return ols_model.fit()