import datetime as dt
import os
from operator import itemgetter
import logging

import numpy as np
import pandas as pd
import pytest
import xarray as xr

from src.data_handling.data_preparation import DataPrep
from src.join import EmptyQueryResult


class TestDataPrep:

    @pytest.fixture
    def data(self):
        """Provide a DataPrep object for station DEBW107 (variables o3 and temp) built via the regular constructor."""
        data_dir = os.path.join(os.path.dirname(__file__), 'data')
        stats = {'o3': 'dma8eu', 'temp': 'maximum'}
        return DataPrep(data_dir, 'AIRBASE', 'DEBW107', ['o3', 'temp'], station_type='background',
                        test='testKWARGS', statistics_per_var=stats)

    @pytest.fixture
    def data_prep_no_init(self):
        """Provide a DataPrep object created without calling ``__init__``; all attributes are assigned by hand."""
        prep = object.__new__(DataPrep)
        here = os.path.abspath(os.path.dirname(__file__))
        attributes = {
            "path": os.path.join(here, 'data'),
            "network": 'UBA',
            "station": ['DEBW107'],
            "variables": ['o3', 'temp'],
            "station_type": "background",
            "sampling": "daily",
            "kwargs": None,
        }
        for name, value in attributes.items():
            setattr(prep, name, value)
        return prep

    def test_init(self, data):
        """Check that the constructor stores its arguments and leaves statistics and windowed data unset."""
        assert data.path == os.path.join(os.path.abspath(os.path.dirname(__file__)), 'data')
        assert data.network == 'AIRBASE'
        assert data.station == ['DEBW107']
        assert data.variables == ['o3', 'temp']
        assert data.station_type == "background"
        assert data.statistics_per_var == {'o3': 'dma8eu', 'temp': 'maximum'}
        # Check every attribute on its own: a combined `not all([...])` would already pass if a single entry
        # were falsy and therefore could not detect an attribute that was set unexpectedly.
        for attr in ('mean', 'std', 'history', 'label'):
            assert getattr(data, attr) is None, f"{attr} should be None directly after init"
        # additional keyword arguments must be kept in kwargs (other entries may be present as well)
        assert {'test': 'testKWARGS'}.items() <= data.kwargs.items()

    def test_init_no_stats(self):
        """Constructing DataPrep with positional arguments only (no ``statistics_per_var``) must raise."""
        positional = ('data/', 'dummy', 'DEBW107', ['o3', 'temp'])
        with pytest.raises(NotImplementedError):
            DataPrep(*positional)

    def test_download_data(self, data_prep_no_init):
        """After download_data, the ``data`` attribute must be an xarray DataArray."""
        prep = data_prep_no_init
        target, meta_target = prep._set_file_name(), prep._set_meta_file_name()
        prep.kwargs = {"store_data_locally": False}
        prep.statistics_per_var = {'o3': 'dma8eu', 'temp': 'maximum'}
        prep.download_data(target, meta_target)
        assert isinstance(prep.data, xr.DataArray)

    def test_download_data_from_join(self, data_prep_no_init):
        """download_data_from_join must return a DataArray together with a meta DataFrame."""
        prep = data_prep_no_init
        target, meta_target = prep._set_file_name(), prep._set_meta_file_name()
        prep.kwargs = {"store_data_locally": False}
        prep.statistics_per_var = {'o3': 'dma8eu', 'temp': 'maximum'}
        result = prep.download_data_from_join(target, meta_target)
        xarr, meta = result
        assert isinstance(xarr, xr.DataArray)
        assert isinstance(meta, pd.DataFrame)

    def test_check_station_meta(self, caplog, data_prep_no_init):
        """check_station_meta must accept matching meta data and raise (after a debug log) on a type mismatch."""
        caplog.set_level(logging.DEBUG)
        file_name = data_prep_no_init._set_file_name()
        meta_file = data_prep_no_init._set_meta_file_name()
        data_prep_no_init.kwargs = {"store_data_locally": False}
        data_prep_no_init.statistics_per_var = {'o3': 'dma8eu', 'temp': 'maximum'}
        data_prep_no_init.download_data(file_name, meta_file)
        # requested station type equals the downloaded meta data -> nothing happens
        assert data_prep_no_init.check_station_meta() is None
        # a deviating request has to be rejected
        data_prep_no_init.station_type = "traffic"
        with pytest.raises(FileNotFoundError):
            data_prep_no_init.check_station_meta()
        msg = "meta data does not agree with given request for station_type: traffic (requested) != background (local)"
        # the most recent record must come from the root logger on debug level and carry the mismatch message
        logger_name, level, text = caplog.record_tuples[-1]
        assert (logger_name, level) == ('root', logging.DEBUG)
        assert msg in text

    def test_load_data_overwrite_local_data(self, data_prep_no_init):
        """With ``overwrite_local_data=True``, load_data must (re)create the data and meta file on every call."""
        data_prep_no_init.statistics_per_var = {'o3': 'dma8eu', 'temp': 'maximum'}
        file_path = data_prep_no_init._set_file_name()
        meta_file_path = data_prep_no_init._set_meta_file_name()
        for path in (file_path, meta_file_path):
            # The local copies may be missing (fresh checkout, different test order). Only delete what exists so
            # that this test does not depend on the state left behind by other tests.
            if os.path.exists(path):
                os.remove(path)
        assert not os.path.exists(file_path)
        assert not os.path.exists(meta_file_path)
        data_prep_no_init.kwargs = {"overwrite_local_data": True}
        # first call: files are absent and have to be created
        data_prep_no_init.load_data()
        assert os.path.exists(file_path)
        assert os.path.exists(meta_file_path)
        t = os.stat(file_path).st_ctime
        tm = os.stat(meta_file_path).st_ctime
        # second call: files exist but must be written again, visible through a newer change time
        data_prep_no_init.load_data()
        assert os.path.exists(file_path)
        assert os.path.exists(meta_file_path)
        assert os.stat(file_path).st_ctime > t
        assert os.stat(meta_file_path).st_ctime > tm
        assert isinstance(data_prep_no_init.data, xr.DataArray)
        assert isinstance(data_prep_no_init.meta, pd.DataFrame)

    def test_load_data_keep_local_data(self, data_prep_no_init):
        """Without the overwrite option, a second load_data call must leave the existing local file untouched."""
        prep = data_prep_no_init
        prep.statistics_per_var = {'o3': 'dma8eu', 'temp': 'maximum'}
        prep.station_type = None
        prep.kwargs = {}
        file_path = prep._set_file_name()
        prep.load_data()
        assert os.path.exists(file_path)
        change_time = os.stat(file_path).st_ctime
        prep.load_data()
        assert os.path.exists(prep._set_file_name())
        # identical change time means the file was not rewritten
        assert os.stat(file_path).st_ctime == change_time
        assert isinstance(prep.data, xr.DataArray)
        assert isinstance(prep.meta, pd.DataFrame)

    def test_repr(self, data_prep_no_init):
        """The representation must list path, network, station, variables, station type and kwargs."""
        path = os.path.join(os.path.abspath(os.path.dirname(__file__)), 'data')
        expected = (f"Dataprep(path='{path}', network='UBA', station=['DEBW107'], "
                    "variables=['o3', 'temp'], station_type=background, **None)")
        # trailing whitespace is ignored on both sides
        assert repr(data_prep_no_init).rstrip() == expected.rstrip()

    def test_set_file_name_and_meta(self):
        """File names for data and meta data must be derived from path, station and the joined variable names."""
        here = os.path.abspath(os.path.dirname(__file__))
        prep = object.__new__(DataPrep)
        prep.path = os.path.join(here, "data")
        prep.station = 'TESTSTATION'
        prep.variables = ['a', 'bc']
        expected_data_file = os.path.join(here, "data/TESTSTATION_a_bc.nc")
        assert prep._set_file_name() == expected_data_file
        expected_meta_file = os.path.join(here, "data/TESTSTATION_a_bc_meta.csv")
        assert prep._set_meta_file_name() == expected_meta_file

    @pytest.mark.parametrize('opts', [{'dim': 'datetime', 'method': 'nearest', 'limit': 10, 'use_coordinate': True},
                                      {'dim': 'datetime', 'limit': 5}, {'dim': 'datetime'}])
    def test_interpolate(self, data, opts):
        """interpolate must yield the same result as ``interpolate_na`` called on the original data."""
        data_org = data.data
        data.interpolate(**opts)
        # Fill in the defaults expected for omitted options. A fresh dict is built on purpose: `opts` is the very
        # object held by the parametrize decorator and must not be modified by the test.
        expected_opts = {"method": 'linear', "limit": None, "use_coordinate": True, **opts}
        assert xr.testing.assert_equal(data_org.interpolate_na(**expected_opts), data.data) is None

    def test_transform_standardise(self, data):
        """The default transform must standardise the data (mean 0, std 1) and store mean and std."""
        for attr in ('_transform_method', 'mean', 'std'):
            assert getattr(data, attr) is None
        data.transform('datetime')
        assert data._transform_method == 'standardise'
        mean_after = data.data.mean('datetime').variable.values
        assert np.testing.assert_almost_equal(mean_after, np.array([[0, 0]])) is None
        std_after = data.data.std('datetime').variable.values
        assert np.testing.assert_almost_equal(std_after, np.array([[1, 1]])) is None
        for stored in (data.mean, data.std):
            assert isinstance(stored, xr.DataArray)

    @pytest.mark.parametrize('mean, std, method, msg', [(10, 3, 'standardise', ''), (6, None, 'standardise', 'std, '),
                                                        (None, 3, 'standardise', 'mean, '), (19, None, 'centre', ''),
                                                        (None, 2, 'centre', 'mean, '), (8, 2, 'centre', ''),
                                                        (None, None, 'standardise', 'mean, std, ')])
    def test_check_inverse_transform_params(self, data, mean, std, method, msg):
        """An empty ``msg`` marks a valid parameter set; otherwise an AttributeError containing ``msg`` is expected."""
        if not msg:
            assert data.check_inverse_transform_params(mean, std, method) is None
            return
        with pytest.raises(AttributeError) as exc_info:
            data.check_inverse_transform_params(mean, std, method)
        assert msg in exc_info.value.args[0]

    def test_transform_centre(self, data):
        """Centring must remove the mean, keep the standard deviation of the data and store no std."""
        for attr in ('_transform_method', 'mean', 'std'):
            assert getattr(data, attr) is None
        std_before = data.data.std('datetime').variable.values
        data.transform('datetime', 'centre')
        assert data._transform_method == 'centre'
        mean_after = data.data.mean('datetime').variable.values
        assert np.testing.assert_almost_equal(mean_after, np.array([[0, 0]])) is None
        std_after = data.data.std('datetime').variable.values
        assert np.testing.assert_almost_equal(std_after, std_before) is None
        assert data.std is None

    @pytest.mark.parametrize('method', ['standardise', 'centre'])
    def test_transform_inverse(self, data, method):
        """Both ways to invert a transformation must restore the original data and reset the stored state."""
        original = data.data
        # two equivalent entry points: the dedicated method and transform(..., inverse=True)
        undo_variants = (
            lambda: data.inverse_transform(),
            lambda: data.transform('datetime', inverse=True),
        )
        for undo in undo_variants:
            data.transform('datetime', method)
            undo()
            assert data._transform_method is None
            assert data.mean is None
            assert data.std is None
            assert np.testing.assert_array_almost_equal(original, data.data) is None

    @pytest.mark.parametrize('method', ['normalise', 'unknownmethod'])
    def test_transform_errors(self, data, method):
        """These methods must raise NotImplementedError; an already set method must trigger an AssertionError."""
        with pytest.raises(NotImplementedError):
            data.transform('datetime', method)
        # pretend a transformation has been applied already
        data._transform_method = method
        with pytest.raises(AssertionError) as exc_info:
            data.transform('datetime', method)
        message = exc_info.value.args[0]
        assert "Transform method is already set." in message

    @pytest.mark.parametrize('method', ['normalise', 'unknownmethod'])
    def test_transform_inverse_errors(self, data, method):
        """Inverting without a set method must fail with AssertionError, these methods with NotImplementedError."""
        with pytest.raises(AssertionError) as exc_info:
            data.inverse_transform()
        message = exc_info.value.args[0]
        assert "Inverse transformation method is not set." in message
        # provide a complete transformation state by hand
        data.mean, data.std = 1, 1
        data._transform_method = method
        with pytest.raises(NotImplementedError):
            data.inverse_transform()

    def test_get_transformation_information(self, data):
        """Before transforming only None values are reported, afterwards mean, std and method name for o3."""
        assert (None, None, None) == data.get_transformation_information("o3")
        expected_mean = data.data.mean("datetime").sel(variables='o3').values
        expected_std = data.data.std("datetime").sel(variables='o3').values
        data.transform('datetime')
        information = data.get_transformation_information("o3")
        mean, std, info = information
        assert np.testing.assert_almost_equal(mean, expected_mean) is None
        assert np.testing.assert_almost_equal(std, expected_std) is None
        assert info == "standardise"

    def test_nan_remove_no_hist_or_label(self, data):
        """history_label_nan_remove must leave history and label at None if either of them is missing."""
        # neither history nor label available
        assert data.history is None and data.label is None
        data.history_label_nan_remove('datetime')
        assert data.history is None and data.label is None
        # history only
        data.make_history_window('datetime', 6)
        assert data.history is not None
        data.history_label_nan_remove('datetime')
        assert data.history is None
        # label only (history has just been reset)
        data.make_labels('variables', 'o3', 'datetime', 2)
        assert data.label is not None
        data.history_label_nan_remove('datetime')
        assert data.label is None

    def test_nan_remove(self, data):
        """After NaN removal history holds no missing values; only the axis at index 2 may have shrunk."""
        data.make_history_window('datetime', -12)
        data.make_labels('variables', 'o3', 'datetime', 3)
        shape_before = data.history.shape
        data.history_label_nan_remove('datetime')
        assert data.history.isnull().sum() == 0
        untouched_axes = itemgetter(0, 1, 3)
        assert untouched_axes(shape_before) == untouched_axes(data.history.shape)
        assert shape_before[2] >= data.history.shape[2]

    def test_create_index_array(self, data):
        """create_index_array must return a named one-dimensional array holding the given range values."""
        cases = ((range(1, 4), [1, 2, 3]), (range(0, 1), [0]))
        for index_range, expected_values in cases:
            index_array = data.create_index_array('window', index_range)
            assert np.testing.assert_array_equal(index_array.data, expected_values) is None
            assert index_array.name == 'window'
            assert index_array.coords.dims == ('window', )

    @staticmethod
    def extract_window_data(res, orig, w):
        """
        Extract comparable 'temp' values of station DEBW107 from shifted and original data.

        :param res: shifted data; selected at datetime 1997-01-06 and flattened
        :param orig: original data; selected on consecutive days and flattened
        :param w: window size; for ``w <= 0`` the selection on ``orig`` covers ``abs(w) + 1`` days beginning
            ``w`` days relative to 1997-01-06, otherwise ``w`` days beginning on 1997-01-07
        :return: tuple of (flattened values from ``res``, flattened values from ``orig``)
        """
        # local names deliberately avoid `slice`, which would shadow the builtin
        window_selection = {'variables': ['temp'], 'Stations': 'DEBW107', 'datetime': dt.datetime(1997, 1, 6)}
        window = res.sel(window_selection).data.flatten()
        if w <= 0:
            delta, periods = w, abs(w) + 1
        else:
            delta, periods = 1, w
        first_day = dt.date(1997, 1, 6) + dt.timedelta(days=delta)
        orig_selection = {'variables': ['temp'], 'Stations': 'DEBW107',
                          'datetime': pd.date_range(first_day, periods=periods, freq='D')}
        orig_slice = orig.sel(orig_selection).data.flatten()
        return window, orig_slice

    def test_shift(self, data):
        """shift must prepend a 'window' dimension of the expected length and contain the shifted values."""
        # (window argument, expected length of the window dimension, check dimension names)
        cases = ((4, 4, True), (-3, 4, False), (0, 1, False))
        for offset, n_windows, check_dims in cases:
            res = data.shift('datetime', offset)
            window, orig = self.extract_window_data(res, data.data, offset)
            if check_dims:
                assert res.coords.dims == ('window', 'Stations', 'datetime', 'variables')
            assert list(res.data.shape) == [n_windows, *data.data.shape]
            assert np.testing.assert_array_equal(orig, window) is None

    def test_make_history_window(self, data):
        """make_history_window must fill ``history``; the sign of the window size must not matter."""
        assert data.history is None
        data.make_history_window('datetime', 5)
        history_positive = data.history
        assert history_positive is not None
        data.make_history_window('datetime', -5)
        assert np.testing.assert_array_equal(data.history, history_positive) is None

    def test_make_labels(self, data):
        """make_labels must create o3 labels with a leading window axis; the sign of the window must not matter."""
        assert data.label is None
        data.make_labels('variables', 'o3', 'datetime', 3)
        assert data.label.variables.data == 'o3'
        expected_shape = [3, *data.data.shape[:2]]
        assert list(data.label.shape) == expected_shape
        label_positive = data.label
        data.make_labels('variables', 'o3', 'datetime', -3)
        assert np.testing.assert_array_equal(data.label, label_positive) is None

    def test_slice(self, data):
        """Slicing from 1997-01-01 to 1997-01-10 must keep 10 entries on axis 1 and leave the other axes as is."""
        start, end = dt.date(1997, 1, 1), dt.date(1997, 1, 10)
        res = data._slice(data.data, start, end, 'datetime')
        other_axes = itemgetter(0, 2)
        assert other_axes(res.shape) == other_axes(data.data.shape)
        assert res.shape[1] == 10

    def test_slice_prep(self, data):
        """_slice_prep must return everything without start/end in kwargs and the requested range otherwise."""
        unsliced = data._slice_prep(data.data)
        assert unsliced.shape == data.data.shape
        # restrict the range to the first ten time steps
        data.kwargs['start'] = unsliced.coords['datetime'][0].values
        data.kwargs['end'] = unsliced.coords['datetime'][9].values
        sliced = data._slice_prep(data.data)
        other_axes = itemgetter(0, 2)
        assert other_axes(sliced.shape) == other_axes(data.data.shape)
        assert sliced.shape[1] == 10

    def test_check_for_neg_concentrations(self, data):
        """The returned o3 values must not fall below 0 by default, nor below an explicitly given minimum."""
        cases = (({}, 0), ({'minimum': 2}, 2))
        for extra_kwargs, lower_bound in cases:
            res = data.check_for_negative_concentrations(data.data, **extra_kwargs)
            assert res.sel({'variables': 'o3'}).min() >= lower_bound

    def test_check_station(self, data):
        """Requesting station type 'traffic' for DEBW107 on network 'dummy' must raise EmptyQueryResult."""
        with pytest.raises(EmptyQueryResult):
            # the instance is intentionally not bound to a name: only the exception raised on creation matters
            DataPrep(os.path.join(os.path.dirname(__file__), 'data'), 'dummy', 'DEBW107', ['o3', 'temp'],
                     station_type='traffic', statistics_per_var={'o3': 'dma8eu', 'temp': 'maximum'})