# -*- coding: utf-8 -*-
"""
Create, Read, Update, Delete (CRUD) functionality for timeseries.

"""

from sqlalchemy import insert
from sqlalchemy.orm import Session
from fastapi.responses import JSONResponse
from . import models
from .models import timeseries_timeseries_roles_table, \
                    timeseries_timeseries_annotations_table, timeseries_timeseries_programmes_table, \
                    DA_enum, SF_enum, AT_enum, DS_enum
from toardb.generic.models import RS_enum, RC_enum
from .schemas import TimeseriesCreate
from toardb.utils.utils import get_value_from_str, get_str_from_value


def get_timeseries(db: Session, timeseries_id: int):
    """Look up one timeseries by its primary key, prepared for the API response.

    Args:
        db: active SQLAlchemy session.
        timeseries_id: value matched against ``models.Timeseries.id``.

    Returns:
        The first matching ``models.Timeseries`` instance, or the falsy query
        result (``None``) when nothing matches. On a hit, ``additional_metadata``
        is replaced by its ``str()`` form. ``access_rights``,
        ``sampling_frequency``, ``aggregation`` and ``source`` are replaced by
        ``get_str_from_value`` lookups against their enum tables.
    """
    found = (
        db.query(models.Timeseries)
        .filter(models.Timeseries.id == timeseries_id)
        .first()
    )
    if not found:
        return found
    # there is a mismatch with additional_metadata (stored form vs. the string
    # form used in the API), so hand back the string representation
    found.additional_metadata = str(found.additional_metadata)
    # NOTE(review): these assignments modify a session-attached instance; if the
    # same session is flushed/committed later, the string values would be
    # written back -- confirm callers never do that.
    enum_columns = (
        ("access_rights", DA_enum),
        ("sampling_frequency", SF_enum),
        ("aggregation", AT_enum),
        ("source", DS_enum),
    )
    for column, enum_table in enum_columns:
        setattr(found, column, get_str_from_value(db, enum_table, getattr(found, column)))
    return found


def get_all_timeseries(db: Session, skip: int = 0, limit: int = None):
    """List timeseries rows (paged), prepared for the API response.

    Args:
        db: active SQLAlchemy session.
        skip: number of rows to skip (passed to ``Query.offset``).
        limit: maximum number of rows (passed to ``Query.limit``); ``None``
            applies no limit.

    Returns:
        A list of ``models.Timeseries`` instances. On each of them,
        ``additional_metadata`` is replaced by its ``str()`` form and the enum
        columns are replaced by their ``get_str_from_value`` string lookups.
    """
    rows = db.query(models.Timeseries).offset(skip).limit(limit).all()
    enum_columns = (
        ("access_rights", DA_enum),
        ("sampling_frequency", SF_enum),
        ("aggregation", AT_enum),
        ("source", DS_enum),
    )
    for row in rows:
        # there is a mismatch with additional_metadata: expose its string form
        row.additional_metadata = str(row.additional_metadata)
        for column, enum_table in enum_columns:
            setattr(row, column, get_str_from_value(db, enum_table, getattr(row, column)))
    return rows


def get_timeseries_by_unique_constraints(db: Session, station_id: int, variable_id: int, label: str):
    """Find the timeseries identified by (station, variable, label).

    Args:
        db: active SQLAlchemy session.
        station_id: matched against ``models.Timeseries.station_id``.
        variable_id: matched against ``models.Timeseries.variable_id``.
        label: matched (after ``str.strip()``) against ``models.Timeseries.label``.

    Returns:
        The first matching ``models.Timeseries`` instance, prepared like
        ``get_timeseries`` does (stringified ``additional_metadata`` and enum
        columns), or the falsy query result (``None``) when nothing matches.
    """
    ts = models.Timeseries
    match = (
        db.query(ts)
        .filter(
            ts.station_id == station_id,
            ts.variable_id == variable_id,
            ts.label == label.strip(),
        )
        .first()
    )
    if match:
        # there is a mismatch with additional_metadata: expose its string form
        match.additional_metadata = str(match.additional_metadata)
        for column, enum_table in (
            ("access_rights", DA_enum),
            ("sampling_frequency", SF_enum),
            ("aggregation", AT_enum),
            ("source", DS_enum),
        ):
            setattr(match, column, get_str_from_value(db, enum_table, getattr(match, column)))
    return match


# TODO: decide whether this stays internal or should also be exposed through
# the public REST API -- and whether it is needed at all.
def get_role_ids_of_timeseries(db: Session, timeseries_id: int):
    """Return every ``TimeseriesTimeseriesRoles`` row linked to ``timeseries_id``.

    Despite the function name, the matching ORM rows are returned as a list,
    not just their ids.
    """
    link_model = models.TimeseriesTimeseriesRoles
    matching = db.query(link_model).filter(link_model.timeseries_id == timeseries_id)
    return matching.all()


# TODO: decide whether this stays internal or should also be exposed through
# the public REST API.
def get_unique_timeseries_role(db: Session, role: int, person_id: int, status: int):
    """Return the first ``TimeseriesRole`` with the given role, person and status.

    ``role`` and ``status`` are compared directly against the stored columns,
    so callers pass the already-converted enum values. Returns ``None`` when no
    row matches.
    """
    role_model = models.TimeseriesRole
    return (
        db.query(role_model)
        .filter(
            role_model.role == role,
            role_model.person_id == person_id,
            role_model.status == status,
        )
        .first()
    )

# TODO: decide whether this stays internal or should also be exposed through
# the public REST API.
def get_unique_timeseries_programme(db: Session, name: str, homepage: str):
    """Return the first ``TimeseriesProgramme`` with this name and homepage, or ``None``."""
    programme_model = models.TimeseriesProgramme
    return (
        db.query(programme_model)
        .filter(
            programme_model.name == name,
            programme_model.homepage == homepage,
        )
        .first()
    )


# TODO: decide whether this stays internal or should also be exposed through
# the public REST API.
def get_unique_timeseries_annotation(db: Session, text: str, contributor_id: int):
    """Return the first ``TimeseriesAnnotation`` with this text and contributor, or ``None``."""
    annotation_model = models.TimeseriesAnnotation
    return (
        db.query(annotation_model)
        .filter(
            annotation_model.text == text,
            annotation_model.contributor_id == contributor_id,
        )
        .first()
    )


def _new_timeseries_role(db: Session, role_data: dict):
    """Build an unsaved TimeseriesRole, converting role/status strings via their enum tables."""
    db_role = models.TimeseriesRole(**role_data)
    db_role.role = get_value_from_str(db, RC_enum, db_role.role)
    db_role.status = get_value_from_str(db, RS_enum, db_role.status)
    return db_role


def _store_and_link(db: Session, items, build, find_existing, table, fk_column: str, timeseries_id: int):
    """Get-or-create a row for each item and link it to a timeseries via an association table.

    Args:
        db: active SQLAlchemy session.
        items: iterable of per-item dicts; ``None`` or empty means nothing to do.
        build: callable turning one item into an unsaved ORM object.
        find_existing: callable returning an equal row already in the
            database for that object, or a falsy value if there is none.
        table: association table receiving ``(timeseries_id, fk_column)`` rows.
        fk_column: name of the association column holding the related id.
        timeseries_id: id of the timeseries to link to.

    An existing row's id is reused; otherwise the new object is committed and
    refreshed to obtain its id. Each association row is inserted and committed
    individually. An id already linked during this call is skipped, so an item
    listed twice does not produce a duplicate association row (or trip a
    uniqueness constraint, if the table has one).
    """
    linked_ids = set()
    for entry in items or ():
        candidate = build(entry)
        existing = find_existing(candidate)
        if existing:
            related_id = existing.id
        else:
            db.add(candidate)
            db.commit()
            db.refresh(candidate)
            related_id = candidate.id
        if related_id in linked_ids:
            continue
        db.execute(insert(table).values(timeseries_id=timeseries_id, **{fk_column: related_id}))
        db.commit()
        linked_ids.add(related_id)


def create_timeseries(db: Session, timeseries: TimeseriesCreate):
    """Insert a new timeseries together with its roles, annotations and programmes.

    Roles, annotations and programmes that already exist (see the
    ``get_unique_timeseries_*`` lookups) are reused rather than duplicated. Each
    one is linked to the new timeseries through its association table.

    Args:
        db: active SQLAlchemy session.
        timeseries: validated creation payload. Its optional ``roles``,
            ``annotations`` and ``programmes`` entries are handled separately
            from the timeseries columns.

    Returns:
        The refreshed ``models.Timeseries`` instance. ``additional_metadata``
        is converted to ``str`` and the enum columns are converted back to
        their string form for the response.
    """
    timeseries_dict = timeseries.dict()
    roles_data = timeseries_dict.pop('roles', None)
    annotations_data = timeseries_dict.pop('annotations', None)
    programmes_data = timeseries_dict.pop('programmes', None)
    db_timeseries = models.Timeseries(**timeseries_dict)
    # the payload carries enum strings; convert them to the stored enum values
    db_timeseries.access_rights = get_value_from_str(db, DA_enum, db_timeseries.access_rights)
    db_timeseries.sampling_frequency = get_value_from_str(db, SF_enum, db_timeseries.sampling_frequency)
    db_timeseries.aggregation = get_value_from_str(db, AT_enum, db_timeseries.aggregation)
    db_timeseries.source = get_value_from_str(db, DS_enum, db_timeseries.source)
    db.add(db_timeseries)
    db.commit()
    db.refresh(db_timeseries)
    timeseries_id = db_timeseries.id
    # NOTE(review): every step below commits on its own. If one fails midway,
    # the timeseries and any links created so far remain in the database.

    # store roles and update association table
    _store_and_link(
        db, roles_data,
        lambda r: _new_timeseries_role(db, r),
        lambda role: get_unique_timeseries_role(db, role.role, role.person_id, role.status),
        timeseries_timeseries_roles_table, 'role_id', timeseries_id,
    )
    # store annotations and update association table
    _store_and_link(
        db, annotations_data,
        lambda a: models.TimeseriesAnnotation(**a),
        lambda annotation: get_unique_timeseries_annotation(db, annotation.text, annotation.contributor_id),
        timeseries_timeseries_annotations_table, 'annotation_id', timeseries_id,
    )
    # store programmes and update association table
    _store_and_link(
        db, programmes_data,
        lambda p: models.TimeseriesProgramme(**p),
        lambda programme: get_unique_timeseries_programme(db, programme.name, programme.homepage),
        timeseries_timeseries_programmes_table, 'programme_id', timeseries_id,
    )

    # there is a mismatch with additional_metadata
    # in upload command, we have now: "additional_metadata": "{}"
    # but return from this method gives (=database): "additional_metadata": {}
    db_timeseries.additional_metadata = str(db_timeseries.additional_metadata)
    # NOTE(review): the string values below are assigned to a session-attached
    # instance; a later flush/commit on this session would write them back.
    db_timeseries.access_rights = get_str_from_value(db, DA_enum, db_timeseries.access_rights)
    db_timeseries.sampling_frequency = get_str_from_value(db, SF_enum, db_timeseries.sampling_frequency)
    db_timeseries.aggregation = get_str_from_value(db, AT_enum, db_timeseries.aggregation)
    db_timeseries.source = get_str_from_value(db, DS_enum, db_timeseries.source)
    return db_timeseries