# Taken from: https://github.com/tiangolo/fastapi/issues/831
from typing import AsyncIterable, Iterator, Optional

import pytest
from sqlalchemy import create_engine
from sqlalchemy.engine import Engine as Database
from sqlalchemy.orm import Session
from sqlalchemy_utils import create_database, database_exists, drop_database
from starlette.testclient import TestClient

from toardb.utils.database import DATABASE_URL
from .models import Base
from toardb.auth_user.models import Base as AuthUserBase
from toardb.contacts.models import Base as ContactBase
from toardb.toardb import app
from toardb.stationmeta.stationmeta import get_db

# URL of the throwaway test database: DATABASE_URL with "_test" appended
# (presumably to the database name -- assumes DATABASE_URL ends with it; confirm).
url = str(DATABASE_URL+ "_test")
# Engine shared by all fixtures below. create_engine() does not connect until
# first use, so the test database does not need to exist at import time.
_db_conn = create_engine(url)

def get_test_db_conn() -> Database:
    """Return the module-level engine that is bound to the test database.

    Asserts that the engine has been created before handing it out.
    """
    engine = _db_conn
    assert engine is not None
    return engine


def get_test_db() -> Iterator[Session]:
    """
    Yield a SQLAlchemy session bound to the test database engine.

    Installed by ``create_test_database`` as the override for the app's
    ``get_db`` dependency. This is a plain (synchronous) generator, so it is
    annotated as ``Iterator`` rather than ``AsyncIterable``. The session is
    closed when the generator is resumed or closed after the caller is done.
    """
    sess = Session(bind=_db_conn)
    try:
        yield sess
    finally:
        # Always release the session's connection back to the pool.
        sess.close()


# Sequences the models expect to exist before 'create_all' runs
# (presumably referenced by the id columns' server defaults -- confirm against
# the model definitions). Order is preserved from the original setup script.
_TEST_DB_SEQUENCES = (
    "auth_user_id_seq",
    "stationmeta_core_id_seq",
    "stationmeta_global_id_seq",
    "stationmeta_global_services_id_seq",
    "stationmeta_annotations_id_seq",
    "stationmeta_roles_id_seq",
    "stationmeta_aux_docs_id_seq",
    "stationmeta_aux_images_id_seq",
    "stationmeta_aux_urls_id_seq",
    "persons_id_seq",
    "organisations_id_seq",
)


def _create_schema_prerequisites(engine) -> None:
    """Install PostGIS and create the id sequences that 'create_all' needs.

    Uses a raw DBAPI connection and always closes it (and its cursor) so the
    connection is returned to the engine's pool instead of staying checked out
    for the whole test session.
    """
    raw_conn = engine.raw_connection()
    try:
        cur = raw_conn.cursor()
        try:
            # 'create_all' fails on the undefined 'Geometry' type unless the
            # PostGIS extension is declared first.
            cur.execute("CREATE EXTENSION postgis")
            raw_conn.commit()
            for seq_name in _TEST_DB_SEQUENCES:
                cur.execute(
                    f"CREATE SEQUENCE public.{seq_name} AS integer START WITH 1 "
                    "INCREMENT BY 1 NO MINVALUE NO MAXVALUE CACHE 1"
                )
                raw_conn.commit()
        finally:
            cur.close()
    finally:
        raw_conn.close()


@pytest.fixture(scope="session", autouse=True)
def create_test_database():
    """
    Create a clean test database once per test session and drop it afterwards.

    Any database already present at ``url`` is dropped first. The
    ``sqlalchemy_utils`` helpers are used to create and drop the database
    consistently. While the session runs, the app's ``get_db`` dependency is
    overridden with ``get_test_db`` so requests hit the test database.
    """
    if database_exists(url):
        drop_database(url)
    create_database(url)  # Create the test database.
    _create_schema_prerequisites(_db_conn)
    # Create the tables.
    AuthUserBase.metadata.create_all(_db_conn)
    ContactBase.metadata.create_all(_db_conn)
    Base.metadata.create_all(_db_conn)
    # Mock the database dependency.
    app.dependency_overrides[get_db] = get_test_db
    yield  # Run the tests.
    # Close the engine's pooled connections to the test database before
    # dropping it, so no idle connection of ours is left dangling.
    _db_conn.dispose()
    drop_database(url)  # Drop the test database.


@pytest.fixture
def test_db_session():
    """
    Yield a SQLAlchemy session bound to the test database, then clean up.

    After the test, the session is closed (rolling back anything it left
    uncommitted) and every row of every model table is deleted inside a
    single transaction.
    """
    session = Session(bind=_db_conn)

    yield session

    # Close (and thereby roll back) the session FIRST: uncommitted writes it
    # still holds keep row locks that would block the DELETEs below.
    session.close()

    # Delete in the reverse of the creation order used by create_test_database
    # (auth_user, contacts, models); within each metadata, reversed
    # sorted_tables puts dependent tables before the tables they reference.
    tables = (
        AuthUserBase.metadata.sorted_tables
        + ContactBase.metadata.sorted_tables
        + Base.metadata.sorted_tables
    )
    # engine.begin() gives one transaction for all deletes and returns the
    # connection to the pool afterwards (Engine.execute is gone in SA 2.0).
    with _db_conn.begin() as conn:
        for tbl in reversed(tables):
            conn.execute(tbl.delete())


@pytest.fixture()
def client():
    """
    Yield a Starlette ``TestClient`` for the FastAPI ``app``.

    The client is entered as a context manager, so the app's startup and
    shutdown handlers run around the test. This fixture performs NO database
    rollback: requests go through whatever ``get_db`` override is installed
    (``create_test_database`` installs ``get_test_db``). Data written during
    a test therefore persists unless another fixture removes it.
    """
    with TestClient(app) as test_client:
        yield test_client