# -*- coding: utf-8 -*-
"""
Create, read, update and delete (CRUD) functionality for timeseries.

"""

from sqlalchemy import insert
from sqlalchemy.orm import Session
from geoalchemy2.elements import WKBElement, WKTElement
from fastapi.responses import JSONResponse
from . import models
from .models import TimeseriesChangelog, timeseries_timeseries_roles_table, \
                    timeseries_timeseries_annotations_table, \
                    DA_enum, SF_enum, AT_enum, DS_enum, MM_enum
from toardb.stationmeta.models import StationmetaCore
from toardb.stationmeta.schemas import get_coordinates_from_geom, get_geom_from_coordinates
from toardb.generic.models import RS_enum, RC_enum
from .schemas import TimeseriesCreate, TimeseriesPatch
from toardb.utils.utils import get_value_from_str, get_str_from_value


def get_timeseries(db: Session, timeseries_id: int):
    """Return the timeseries with the given id, or None if there is none.

    The found object is adjusted in place to work around a representation
    mismatch: ``additional_metadata`` of the timeseries and of its station is
    turned into a string with single quotes replaced by double quotes, and
    station coordinates held as a geometry element (WKB/WKT) are converted
    with ``get_coordinates_from_geom``.
    """
    found = db.query(models.Timeseries).filter_by(id=timeseries_id).first()
    if not found:
        return found
    found.additional_metadata = str(found.additional_metadata).replace("'", '"')
    station = found.station
    # presumably the station object may already have been converted earlier in
    # this session, hence the type check before converting again
    geometry_types = (WKBElement, WKTElement)
    if isinstance(station.coordinates, geometry_types):
        station.coordinates = get_coordinates_from_geom(station.coordinates)
    station.additional_metadata = str(station.additional_metadata).replace("'", '"')
    return found


def get_all_timeseries(db: Session, skip : int = 0, limit: int = None):
    """Return a page of timeseries.

    :param skip: number of rows to skip (offset)
    :param limit: maximum number of rows to return (None is passed straight to
        ``Query.limit``, i.e. no limit)
    :return: list of timeseries objects, each adjusted in place: its
        ``additional_metadata`` and its station's ``additional_metadata`` are
        turned into strings with single quotes replaced by double quotes, and
        geometry-typed station coordinates are converted with
        ``get_coordinates_from_geom``
    """
    query = db.query(models.Timeseries).offset(skip).limit(limit)
    page = query.all()
    geometry_types = (WKBElement, WKTElement)
    for ts in page:
        ts.additional_metadata = str(ts.additional_metadata).replace("'", '"')
        station = ts.station
        if isinstance(station.coordinates, geometry_types):
            station.coordinates = get_coordinates_from_geom(station.coordinates)
        station.additional_metadata = str(station.additional_metadata).replace("'", '"')
    return page


def get_timeseries_by_unique_constraints(db: Session, station_id: int, variable_id: int, label: str):
    """Return the timeseries identified by station, variable and label, or None.

    ``label`` is stripped of surrounding whitespace before comparison. A found
    object is adjusted in place: ``additional_metadata`` of the timeseries and
    of its station become strings with single quotes replaced by double
    quotes, and geometry-typed station coordinates are converted with
    ``get_coordinates_from_geom``.
    """
    candidate = (
        db.query(models.Timeseries)
        .filter_by(station_id=station_id, variable_id=variable_id, label=label.strip())
        .first()
    )
    if not candidate:
        return candidate
    candidate.additional_metadata = str(candidate.additional_metadata).replace("'", '"')
    station = candidate.station
    if isinstance(station.coordinates, (WKBElement, WKTElement)):
        station.coordinates = get_coordinates_from_geom(station.coordinates)
    station.additional_metadata = str(station.additional_metadata).replace("'", '"')
    return candidate


# TODO: decide whether this helper is internal only or should also be exposed
# through the public REST API -- and whether it is needed at all.
def get_role_ids_of_timeseries(db: Session, timeseries_id: int):
    """Return all ``models.TimeseriesTimeseriesRoles`` objects of a timeseries.

    Despite the name, the mapped objects themselves are returned (as a list),
    not bare role ids.
    """
    query = db.query(models.TimeseriesTimeseriesRoles)
    return query.filter_by(timeseries_id=timeseries_id).all()


# TODO: decide whether this lookup is internal only or should also be exposed
# through the public REST API.
def get_unique_timeseries_role(db: Session, role: int, contact_id: int, status: int):
    """Return the first role matching ``role``, ``contact_id`` and ``status``, or None.

    ``create_timeseries`` and ``patch_timeseries`` call this with the already
    converted enum values, to reuse an existing role instead of storing a
    duplicate.
    """
    query = db.query(models.TimeseriesRole)
    return query.filter_by(role=role, contact_id=contact_id, status=status).first()

# TODO: decide whether this lookup is internal only or should also be exposed
# through the public REST API.
def get_unique_timeseries_programme(db: Session, name: str, homepage: str):
    """Return the first programme with exactly this ``name`` and ``homepage``, or None."""
    query = db.query(models.TimeseriesProgramme)
    return query.filter_by(name=name, homepage=homepage).first()


# TODO: decide whether this lookup is internal only or should also be exposed
# through the public REST API.
def get_unique_timeseries_annotation(db: Session, text: str, contributor_id: int):
    """Return the first annotation with this ``text`` and ``contributor_id``, or None.

    ``create_timeseries`` and ``patch_timeseries`` call this to reuse an
    existing annotation instead of storing a duplicate.
    """
    query = db.query(models.TimeseriesAnnotation)
    return query.filter_by(text=text, contributor_id=contributor_id).first()


def create_timeseries(db: Session, timeseries: TimeseriesCreate):
    """Insert a new timeseries together with its roles and annotations.

    Enum-like fields of the payload are converted with ``get_value_from_str``
    before the row is stored. Every role/annotation in the payload is reused if
    an identical one already exists (``get_unique_timeseries_role`` /
    ``get_unique_timeseries_annotation``), otherwise it is created; it is then
    linked to the new timeseries through the association table.

    :param db: database session
    :param timeseries: payload of the new timeseries
    :return: the stored timeseries; its ``additional_metadata`` is returned as
        a string with single quotes replaced by double quotes
    """
    payload = timeseries.dict()
    roles_data = payload.pop('roles', None)
    annotations_data = payload.pop('annotations', None)
    db_timeseries = models.Timeseries(**payload)
    # convert the enum strings of the payload to the values stored in the database
    for attr_name, enum_type in (('access_rights', DA_enum),
                                 ('sampling_frequency', SF_enum),
                                 ('aggregation', AT_enum),
                                 ('source', DS_enum),
                                 ('measurement_method', MM_enum)):
        raw_value = getattr(db_timeseries, attr_name)
        setattr(db_timeseries, attr_name, get_value_from_str(enum_type, raw_value))
    db.add(db_timeseries)
    db.commit()
    db.refresh(db_timeseries)
    # the id is only known after the insert has been committed
    new_id = db_timeseries.id
    # NOTE(review): every role/annotation and every link is committed
    # separately, so an error part-way through leaves the timeseries stored
    # with only some of its roles/annotations.
    for role_fields in roles_data or ():
        candidate_role = models.TimeseriesRole(**role_fields)
        candidate_role.role = get_value_from_str(RC_enum, candidate_role.role)
        candidate_role.status = get_value_from_str(RS_enum, candidate_role.status)
        existing_role = get_unique_timeseries_role(db, candidate_role.role, candidate_role.contact_id, candidate_role.status)
        if existing_role:
            role_id = existing_role.id
        else:
            db.add(candidate_role)
            db.commit()
            db.refresh(candidate_role)
            role_id = candidate_role.id
        db.execute(insert(timeseries_timeseries_roles_table).values(timeseries_id=new_id, role_id=role_id))
        db.commit()
    for annotation_fields in annotations_data or ():
        candidate_annotation = models.TimeseriesAnnotation(**annotation_fields)
        existing_annotation = get_unique_timeseries_annotation(db, candidate_annotation.text, candidate_annotation.contributor_id)
        if existing_annotation:
            annotation_id = existing_annotation.id
        else:
            db.add(candidate_annotation)
            db.commit()
            db.refresh(candidate_annotation)
            annotation_id = candidate_annotation.id
        db.execute(insert(timeseries_timeseries_annotations_table).values(timeseries_id=new_id, annotation_id=annotation_id))
        db.commit()
    # mismatch with additional_metadata: the upload carries a string such as
    # "{}", while the database object holds {} -- hand back the string form
    db_timeseries.additional_metadata = str(db_timeseries.additional_metadata).replace("'", '"')
    return db_timeseries


def patch_timeseries(db: Session, description: str, timeseries_id: int, timeseries: TimeseriesPatch):
    """Apply a partial update to an existing timeseries and log it in the changelog.

    Only fields of ``timeseries`` whose value is not None are written. Enum-like
    fields are converted with ``get_value_from_str``, as in
    ``create_timeseries``. Roles and annotations in the patch are reused if an
    identical one already exists, otherwise created, and then linked to the
    timeseries through the association tables.

    One ``TimeseriesChangelog`` entry is written per call. When at least one
    field was patched, its ``old_value``/``new_value`` hold the string form of
    a ``{field: value}`` dict covering every patched field.

    :param db: database session
    :param description: description stored with the changelog entry
    :param timeseries_id: id of the timeseries to patch
    :param timeseries: patch payload; fields set to None are left unchanged
    :return: the patched timeseries; its ``additional_metadata`` is returned as
        a string with single quotes replaced by double quotes
    """
    timeseries_dict = timeseries.dict()
    roles_data = timeseries_dict.pop('roles', None)
    annotations_data = timeseries_dict.pop('annotations', None)
    # enum-like fields are stored as their database values, as in create_timeseries
    # NOTE(review): assumes TimeseriesPatch delivers them as strings like
    # TimeseriesCreate does -- confirm against the schema
    enum_types = {'access_rights': DA_enum,
                  'sampling_frequency': SF_enum,
                  'aggregation': AT_enum,
                  'source': DS_enum,
                  'measurement_method': MM_enum}
    # prepare changelog entry
    # NOTE(review): author_id and type_of_change are hard-coded to 1
    db_changelog = TimeseriesChangelog(description=description, timeseries_id=timeseries_id, author_id=1, type_of_change=1)
    db_timeseries = db.query(models.Timeseries).get(timeseries_id)
    # record every patched field (the previous code overwrote old/new value on
    # each iteration, keeping only the last field and not its name)
    old_values = {}
    new_values = {}
    for k, v in timeseries_dict.items():
        if v is None:
            continue
        if k in enum_types:
            v = get_value_from_str(enum_types[k], v)
        old_values[k] = str(getattr(db_timeseries, k))
        setattr(db_timeseries, k, v)
        new_values[k] = str(getattr(db_timeseries, k))
    if old_values:
        db_changelog.old_value = str(old_values)
        db_changelog.new_value = str(new_values)
    # problems with coordinates: the station object in this session may hold
    # coordinates in converted (non-geometry) form, presumably from one of the
    # get_* functions; turn them back into a geometry before committing and
    # restore the previous value afterwards
    db_stationmeta = db.query(StationmetaCore).get(db_timeseries.station_id)
    tmp_coordinates = db_stationmeta.coordinates
    db_stationmeta.coordinates = get_geom_from_coordinates(db_stationmeta.coordinates)
    db.commit()
    db.refresh(db_timeseries)
    # store roles and update association table
    if roles_data:
        for r in roles_data:
            db_role = models.TimeseriesRole(**r)
            db_role.role = get_value_from_str(RC_enum, db_role.role)
            db_role.status = get_value_from_str(RS_enum, db_role.status)
            # check whether role is already present in database
            db_object = get_unique_timeseries_role(db, db_role.role, db_role.contact_id, db_role.status)
            if db_object:
                role_id = db_object.id
            else:
                db.add(db_role)
                db.commit()
                db.refresh(db_role)
                role_id = db_role.id
            db.execute(insert(timeseries_timeseries_roles_table).values(timeseries_id=timeseries_id, role_id=role_id))
            db.commit()
    # store annotations and update association table
    if annotations_data:
        for a in annotations_data:
            db_annotation = models.TimeseriesAnnotation(**a)
            # check whether annotation is already present in database
            db_object = get_unique_timeseries_annotation(db, db_annotation.text, db_annotation.contributor_id)
            if db_object:
                annotation_id = db_object.id
            else:
                db.add(db_annotation)
                db.commit()
                db.refresh(db_annotation)
                annotation_id = db_annotation.id
            db.execute(insert(timeseries_timeseries_annotations_table).values(timeseries_id=timeseries_id, annotation_id=annotation_id))
            db.commit()
    # add patch to changelog table
    db.add(db_changelog)
    db.commit()
    # there's a mismatch with coordinates --> how to automatically switch back and forth?!
    # NOTE(review): this (and the additional_metadata rewrite below) leaves the
    # objects dirty in the session; a later commit would write these values
    db_stationmeta.coordinates = tmp_coordinates
    # mismatch with additional_metadata: the upload carries a string such as
    # "{}", while the database object holds {} -- hand back the string form
    db_timeseries.additional_metadata = str(db_timeseries.additional_metadata).replace("'", '"')
    return db_timeseries