#taken from: https://github.com/tiangolo/fastapi/issues/831
from typing import Optional, AsyncIterable, Iterator

import pytest
from sqlalchemy import create_engine
from sqlalchemy.engine import Engine as Database
from sqlalchemy.orm import Session
from sqlalchemy_utils import database_exists, create_database, drop_database
from starlette.testclient import TestClient

from toardb.utils.database import DATABASE_URL
from .models import Base
from toardb.variables.models import Base as VariableBase
from toardb.stationmeta.models import Base as StationmetaBase
from toardb.auth_user.models import Base as AuthUserBase
from toardb.contacts.models import Base as ContactBase
from toardb.toardb import app
from toardb.timeseries.timeseries import get_db
 
# URL of the dedicated test database: the application's DATABASE_URL with
# "_test" appended, so tests do not touch the database DATABASE_URL points to.
url = str(DATABASE_URL+ "_test")
# Engine bound to the test database; shared by the helpers and fixtures below.
_db_conn = create_engine(url)

def get_test_db_conn() -> Database:
    """Return the module-level engine that is bound to the test database.

    Raises:
        AssertionError: if the module-level engine is ``None``.
    """
    engine = _db_conn
    assert engine is not None
    return engine


def get_test_db() -> Iterator[Session]:
    """Yield a session bound to the test database and close it afterwards.

    This is a plain (synchronous) generator, hence the ``Iterator`` return
    annotation. It is installed in ``app.dependency_overrides`` as the
    replacement for ``get_db`` while the tests run.

    Yields:
        A ``Session`` bound to the test engine. The session is always closed
        when the generator is finalised, even if the consumer raised.
    """
    db_session = Session(bind=_db_conn)

    try:
        yield db_session
    finally:
        db_session.close()


# Id sequences that are created by hand before the tables are created.
_ID_SEQUENCE_NAMES = (
    "auth_user_id_seq",
    "variables_id_seq",
    "stationmeta_core_id_seq",
    "stationmeta_global_id_seq",
    "stationmeta_global_services_id_seq",
    "stationmeta_roles_id_seq",
    "stationmeta_annotations_id_seq",
    "stationmeta_aux_docs_id_seq",
    "stationmeta_aux_images_id_seq",
    "stationmeta_aux_urls_id_seq",
    "persons_id_seq",
    "organisations_id_seq",
    "timeseries_annotations_id_seq",
    "timeseries_roles_id_seq",
    "timeseries_programmes_id_seq",
    "timeseries_id_seq",
)

_CREATE_SEQUENCE_SQL = (
    "CREATE SEQUENCE public.{} AS integer "
    "START WITH 1 INCREMENT BY 1 NO MINVALUE NO MAXVALUE CACHE 1"
)


def _prepare_test_schema_prerequisites():
    """Install the PostGIS extension and create the id sequences.

    Uses a raw DBAPI connection from the test engine. Cursor and connection
    are always closed so the connection goes back to the pool.
    """
    raw_conn = _db_conn.raw_connection()
    try:
        cursor = raw_conn.cursor()
        try:
            # 'create_all' fails with an undefined 'Geometry' type unless the
            # PostGIS extension has been declared first.
            cursor.execute("CREATE EXTENSION postgis")
            raw_conn.commit()
            # The *_id_seq sequences are needed as well.
            for sequence_name in _ID_SEQUENCE_NAMES:
                cursor.execute(_CREATE_SEQUENCE_SQL.format(sequence_name))
            raw_conn.commit()
        finally:
            cursor.close()
    finally:
        raw_conn.close()


@pytest.fixture(scope="session", autouse=True)
def create_test_database():
    """
    Create a clean test database once per test session and drop it afterwards.

    Any left-over test database is dropped first. The `sqlalchemy_utils`
    package provides the helpers for checking, creating and dropping the
    database. After the PostGIS extension and the id sequences exist, the
    tables of all model bases are created and the application's `get_db`
    dependency is overridden with `get_test_db`.

    The teardown runs in a `finally` block, so the dependency override is
    removed and the test database is dropped even when the setup fails
    part-way through.
    """
    if database_exists(url):
        drop_database(url)
    create_database(url)  # Create the test database.
    try:
        _prepare_test_schema_prerequisites()
        # Create the tables of every model base.
        AuthUserBase.metadata.create_all(_db_conn)
        ContactBase.metadata.create_all(_db_conn)
        VariableBase.metadata.create_all(_db_conn)
        StationmetaBase.metadata.create_all(_db_conn)
        Base.metadata.create_all(_db_conn)
        app.dependency_overrides[get_db] = get_test_db  # Mock the Database Dependency
        yield  # Run the tests.
    finally:
        app.dependency_overrides.pop(get_db, None)
        # Release pooled connections before the database is removed.
        _db_conn.dispose()
        drop_database(url)  # Drop the test database.


@pytest.fixture
def test_db_session():
    """Yield an SQLAlchemy session and wipe all table data after the test.

    Yields:
        A ``Session`` bound to the test engine.

    Teardown closes the session first, then deletes the rows of every table
    of all model bases in reverse dependency order.
    """
    session = Session(bind=_db_conn)

    yield session

    # Close the session before cleaning up: this ends any transaction the test
    # left open (so it cannot block the deletes, which run on a different
    # connection) and puts the connection back into the connection pool.
    session.close()

    all_tables = (AuthUserBase.metadata.sorted_tables +
                  ContactBase.metadata.sorted_tables +
                  VariableBase.metadata.sorted_tables +
                  StationmetaBase.metadata.sorted_tables +
                  Base.metadata.sorted_tables)
    # Drop all data after each test. 'Engine.execute' is deprecated in
    # SQLAlchemy 1.4 and removed in 2.0, so run the deletes on an explicit
    # connection; 'begin()' commits on success and rolls back on error.
    with _db_conn.begin() as connection:
        for tbl in reversed(all_tables):
            connection.execute(tbl.delete())


@pytest.fixture()
def client():
    """
    Yield a `TestClient` wrapping the application.

    The client is entered as a context manager and exited once the test
    that requested this fixture has finished.

    NOTE(review): this fixture itself performs no database rollback or
    cleanup; per-test data cleanup is done by `test_db_session`.
    """
    with TestClient(app) as test_client:
        yield test_client