# -*- coding: utf-8 -*-
"""
Create, Read, Update, Delete functionality

"""

from sqlalchemy import insert, delete, select, and_
from sqlalchemy.orm import Session
from sqlalchemy.engine import Engine
from geoalchemy2.elements import WKBElement, WKTElement
from fastapi import File, UploadFile
from fastapi.responses import JSONResponse, FileResponse
from . import models, schemas
from .models import DF_enum
from toardb.variables import models as variables_models
from toardb.stationmeta import models as stationmeta_models
from toardb.stationmeta.crud import get_stationmeta_core
from toardb.stationmeta.schemas import get_coordinates_from_geom, get_geom_from_coordinates
from toardb.timeseries.models import TimeseriesChangelog
from toardb.timeseries.crud import get_timeseries_by_unique_constraints
from toardb.generic.models import CL_enum
from toardb.utils.utils import get_value_from_str, get_str_from_value
import datetime as dt
import pandas as pd
from io import StringIO
import csv

def get_data(db: Session, timeseries_id: int, format: str):
    data = db.query(models.Data).filter(models.Data.timeseries_id == timeseries_id).all()
    if format == 'json':
        return data
    elif format == 'csv':
        # dump the records to a temporary CSV file in the working directory
        with open('mydump.csv', 'w') as outfile:
            outcsv = csv.writer(outfile)
            # write header line
            outcsv.writerow([column.name for column in models.Data.__mapper__.columns])
            # now the data
            for curr in data:
                outcsv.writerow([getattr(curr, column.name) for column in models.Data.__mapper__.columns])
        return FileResponse('mydump.csv')
    else:
        status_code = 400
        message = 'Invalid format!'
        return JSONResponse(status_code=status_code, content=message)
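
# A minimal (hypothetical) FastAPI route using get_data; the router object, the
# get_db dependency, and the URL layout are assumptions, not part of this module:
#
#     @router.get('/data/{timeseries_id}')
#     def read_data(timeseries_id: int, format: str = 'json',
#                   db: Session = Depends(get_db)):
#         return get_data(db, timeseries_id=timeseries_id, format=format)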


def get_data_by_datetime_and_timeseriesid(db: Session, datetime: dt.datetime, timeseries_id: int):
    # Query.filter takes its criteria as positional arguments, not as a list
    return db.query(models.Data).filter(models.Data.datetime == datetime,
                                        models.Data.timeseries_id == timeseries_id).first()


def get_all_data(db: Session, skip: int = 0, limit: int = 100):
    return db.query(models.Data).offset(skip).limit(limit).all()


def create_data_record(db: Session, engine: Engine,
        series_id: int, datetime: dt.datetime,
        value: float, flag: str, version: str):
    # translate the flag name into its numeric database representation
    flag_num = get_value_from_str(DF_enum, flag)
    data_dict = {"datetime": datetime,
                 "value": value,
                 "flags": flag_num,
                 "version": version,
                 "timeseries_id": series_id}
    data = models.Data(**data_dict)
    db.add(data)
    db.commit()
    db.refresh(data)
    status_code = 200
    message = 'Data successfully inserted!'
    return JSONResponse(status_code=status_code, content=message)
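
# A minimal (hypothetical) call, assuming an open Session/Engine and that 'OK'
# is a valid flag name in DF_enum (an assumption; the valid names are defined
# in toardb's models):
#
#     create_data_record(db, engine, series_id=1,
#                        datetime=dt.datetime(2012, 1, 1, tzinfo=dt.timezone.utc),
#                        value=0.581, flag='OK',
#                        version='000001.000000.00000000000000')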


def create_data(db: Session, engine: Engine, input_handle: UploadFile = File(...)):
    # a timeseries is defined by the unique_constraint of (station_id, variable_id, label)
    # station_id: from header
    # variable_id: from database (with variable_name -- from filename)
    # label: unique information for *this* timeseries (station_id, variable_id) -- at the moment: contributor_shortname 
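    #
    # The parser below assumes an input file of roughly the following shape
    # (all values here are hypothetical), with the variable name taken from the
    # first underscore-separated token of the filename, e.g. "toluene_SDZ54421_2012.dat":
    #
    #   #station_id: SDZ54421
    #   #contributor_shortname: UBA
    #   #timeshift_from_utc: 8
    #   2012-01-01 00:00:00;0.581;0
    #   ...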
    # get variable_name from filename
    variable_name = input_handle.filename.split('_')[0]
    variable = db.query(variables_models.Variable).filter(variables_models.Variable.name == variable_name).first()
    variable_id = variable.id
    # get header information (station_id, contributor_shortname, timeshift_from_utc)
    line = '#bla'   # dummy value so that the header loop below is entered
    f = input_handle.file
    prev = pos = 0
    while line.startswith('#'):
        line = f.readline().decode('utf-8')
        key = line.split(':')[0].lower().strip()
        if key == "#station_id":
            station_id = line.split(':')[1].strip()
        if key == "#contributor_shortname":
            label = line.split(':')[1].strip()
        if key == "#timeshift_from_utc":
            timeoffset = dt.timedelta(hours=float(line.split(':')[1]))
        prev, pos = pos, f.tell()
    # rewind to the beginning of the first non-header line
    f.seek(prev)
    station_code = station_id
    stationmeta_core = get_stationmeta_core(db=db, station_code=station_code)
    station_id = stationmeta_core.id
    timeseries = get_timeseries_by_unique_constraints(db=db, station_id=station_id, variable_id=variable_id, label=label)
    # default version label (presumably <major>.<minor>.<YYYYMMDDHHMMSS>)
    version = '000001.000000.00000000000000'
    if timeseries:
        timeseries_id = timeseries.id
        # read the data from the SpooledTemporaryFile, skipping the header lines
        df = pd.read_csv(input_handle.file, comment='#', header=None, sep=';',
                         names=["time", "value", "flags"], parse_dates=["time"], index_col="time")
        # subtract the timeshift to convert the data to UTC
        df.index = df.index - timeoffset
        # now insert the timeseries_id at the end of the data frame
        df.insert(2, 'timeseries_id', timeseries_id)
        # also insert version
        df.insert(3, 'version', version)
        # datetime needs timezone information
        df = df.tz_localize('UTC')
        buf = StringIO()
        df.to_csv(buf, header=False)
        buf.seek(0)
        # bulk-insert via PostgreSQL COPY on the raw psycopg2 connection,
        # which is much faster than row-by-row ORM inserts
        fake_conn = engine.raw_connection()
        fake_cur = fake_conn.cursor()
        try:
            fake_cur.copy_from(buf, 'data', sep=',', columns=('datetime', 'value', 'flags', 'timeseries_id', 'version'))
            fake_conn.commit()
            message = 'Data successfully inserted.'
            status_code = 200
        except Exception as e:
            message = 'An error occurred in data insertion: %s' % (e,)
            status_code = 400
        finally:
            fake_conn.close()
    else:
        message = f'Timeseries not found for station {station_code.strip()}, variable {variable_name}, label {label.strip()}'
        status_code = 400
    return JSONResponse(status_code=status_code, content=message)
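
# Hypothetical route wiring for create_data; the router object, the get_db and
# get_engine dependencies, and the URL are assumptions, not part of this module:
#
#     @router.post('/data/')
#     def insert_data(file: UploadFile = File(...), db: Session = Depends(get_db),
#                     engine: Engine = Depends(get_engine)):
#         return create_data(db, engine, input_handle=file)
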

def patch_data(db: Session, engine: Engine, description: str, version: str, input_handle: UploadFile = File(...)):
    # a timeseries is defined by the unique_constraint of (station_id, variable_id, label)
    # station_id: from header
    # variable_id: from database (with variable_name -- from filename)
    # label: unique information for *this* timeseries (station_id, variable_id) -- at the moment: contributor_shortname 
    # get variable_name from filename

    # the version label has to be unique for this timeseries ==> still to be checked
    # (see the sketch after the timeseries lookup below)

    variable_name = input_handle.filename.split('_')[0]
    variable = db.query(variables_models.Variable).filter(variables_models.Variable.name == variable_name).first()
    variable_id = variable.id
    # get header information (station_id, contributor_shortname, timeshift_from_utc)
    line = '#bla'   # dummy value so that the header loop below is entered
    f = input_handle.file
    prev = pos = 0
    while line.startswith('#'):
        line = f.readline().decode('utf-8')
        key = line.split(':')[0].lower().strip()
        if key == "#station_id":
            station_id = line.split(':')[1].strip()
        if key == "#contributor_shortname":
            label = line.split(':')[1].strip()
        if key == "#timeshift_from_utc":
            timeoffset = dt.timedelta(hours=float(line.split(':')[1]))
        prev, pos = pos, f.tell()
    # rewind to the beginning of the first non-header line
    f.seek(prev)
    station_code = station_id
    stationmeta_core = get_stationmeta_core(db=db,station_code=station_code)
    # work around a type mismatch: convert the coordinates back into a geometry element
    stationmeta_core.coordinates = get_geom_from_coordinates(stationmeta_core.coordinates)
    station_id = stationmeta_core.id
    timeseries = get_timeseries_by_unique_constraints(db=db,station_id=station_id,variable_id=variable_id,label=label)
    if timeseries:
        timeseries_id = timeseries.id
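        # A possible implementation of the version-label uniqueness check noted
        # above (a sketch, assuming changelog entries record the version label of
        # each patch):
        #
        #     already_used = db.query(TimeseriesChangelog).filter(
        #         TimeseriesChangelog.timeseries_id == timeseries_id,
        #         TimeseriesChangelog.version == version).first()
        #     if already_used:
        #         return JSONResponse(status_code=400,
        #                             content='Version label already in use!')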
        # read the data from the SpooledTemporaryFile, skipping the header lines
        df = pd.read_csv(input_handle.file, comment='#', header=None, sep=';',
                         names=["time", "value", "flags"], parse_dates=["time"], index_col="time")
        # subtract the timeshift to convert the data to UTC
        df.index = df.index - timeoffset
        # now insert the timeseries_id at the end of the data frame
        df.insert(2, 'timeseries_id', timeseries_id)
        # also insert version
        df.insert(3, 'version', version)
        # datetime needs timezone information
        df = df.tz_localize('UTC')
        # determine period_start and period_end of data
        period_start = min(df.index)
        period_end = max(df.index)
        # move the data from this period to data_archive
        # note: the following single-statement variant did not work as intended
        # (only one record ended up being inserted):
        # db.execute(insert(models.DataArchive).from_select([models.DataArchive],
        #                                            select([models.Data]).where(
        #     and_(models.Data.timeseries_id == timeseries_id,
        #          models.Data.datetime >= period_start,
        #          models.Data.datetime <= period_end))))
        # instead, fetch the affected rows and copy them over explicitly
        result = db.execute(select([models.Data]).where(
            and_(models.Data.timeseries_id == timeseries_id,
                 models.Data.datetime >= period_start,
                 models.Data.datetime <= period_end)))
        rows = result.fetchall()
        # the positional indices rely on the column order of the 'data' table:
        # (datetime, value, flags, version, timeseries_id)
        for row in rows:
            db_obj = models.DataArchive(datetime=row[0], value=row[1], flags=row[2], version=row[3], timeseries_id=row[4])
            db.add(db_obj)
        db.commit()
        db.execute(delete(models.Data).where(
            and_(models.Data.timeseries_id == timeseries_id,
                 models.Data.datetime >= period_start,
                 models.Data.datetime <= period_end)))
        db.commit()
        # now insert the new data for this period from the file
        buf = StringIO()
        df.to_csv(buf, header=False)
        buf.seek(0)
        # bulk-insert via PostgreSQL COPY on the raw psycopg2 connection
        fake_conn = engine.raw_connection()
        fake_cur = fake_conn.cursor()
        try:
            fake_cur.copy_from(buf, 'data', sep=',', columns=('datetime', 'value', 'flags', 'timeseries_id', 'version'))
            fake_conn.commit()
            message = 'Data successfully inserted.'
            status_code = 200
        except Exception as e:
            message = 'An error occurred in data insertion: %s' % (e,)
            status_code = 400
        finally:
            fake_conn.close()
        # create changelog entry
        # how to determine type_of_change?
        # 4: unspecified data value corrections (this also holds if only a single
        #    value is corrected; the term "unspecified" keeps the option open to
        #    add "specified" corrections later, e.g. from QC)
        # 5: replaced data with a new version
        type_of_change = get_value_from_str(CL_enum, "UnspecifiedData")
        db_changelog = TimeseriesChangelog(description=description, timeseries_id=timeseries_id, author_id=1, type_of_change=type_of_change,
                                            old_value="", new_value="", period_start=period_start, period_end=period_end, version=version)
        db.add(db_changelog)
        db.commit()
    else:
        message = f'Timeseries not found for station {station_code.strip()}, variable {variable_name}, label {label.strip()}'
        status_code = 400
    return JSONResponse(status_code=status_code, content=message)