#taken from: https://github.com/tiangolo/fastapi/issues/831
import pytest
from starlette.testclient import TestClient
from sqlalchemy import create_engine
from sqlalchemy.engine import Engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy_utils import database_exists, create_database, drop_database

from toardb.base import Base
from toardb.toardb import app
from toardb.utils.database import DATABASE_URL, get_db, get_engine

# Connection URL of the throw-away PostgreSQL database used by the test suite.
# The session fixture below drops and re-creates this database, so it must
# never point at a database whose contents matter.
url = "postgresql://toaradmin:toar2@localhost:5432/toardb_test"

# Single engine shared by every helper and fixture in this module.
_db_conn = create_engine(url)

def get_test_engine() -> Engine:
    """Return the module-wide engine bound to the test database.

    Installed as the override for the application's ``get_engine`` dependency.
    """
    engine = _db_conn
    assert engine is not None
    return engine


def get_test_db():
    """Yield one SQLAlchemy session bound to the test database.

    Used as the override for the application's ``get_db`` dependency. The
    session is closed when the generator is resumed or closed after use.
    """
    session_factory = sessionmaker(bind=_db_conn, autoflush=False, autocommit=False)
    db = session_factory()
    try:
        yield db
    finally:
        db.close()


@pytest.fixture(scope="session", autouse=True)
def create_test_database():
    """
    Create a clean test database once per test session and drop it afterwards.

    An existing database at ``url`` is dropped first, so every session starts
    from scratch. ``Base.metadata.create_all`` alone fails on this schema
    (``Geometry`` is undefined without PostGIS), so the ``postgis`` and
    ``toar_controlled_vocabulary`` extensions are created before the tables.
    While the tests run, the app's ``get_db`` and ``get_engine`` dependencies
    are overridden to use the test database. The overrides are removed again
    on teardown.

    The `sqlalchemy_utils` helpers are used to create and drop the database
    consistently.
    """
    if database_exists(url):
        drop_database(url)
    create_database(url)  # Create the test database.

    # Declare the extensions the schema depends on. A raw DBAPI connection is
    # checked out from the pool here. Make sure the cursor and connection are
    # released again, even if a statement fails, instead of leaking them for
    # the whole session.
    raw_conn = _db_conn.raw_connection()
    try:
        cursor = raw_conn.cursor()
        try:
            for statement in (
                "CREATE EXTENSION IF NOT EXISTS postgis",
                "CREATE EXTENSION IF NOT EXISTS toar_controlled_vocabulary",
            ):
                cursor.execute(statement)
                raw_conn.commit()
        finally:
            cursor.close()
    finally:
        raw_conn.close()  # returns the connection to the pool

    Base.metadata.create_all(_db_conn)  # Create the tables.
    app.dependency_overrides[get_db] = get_test_db  # Mock the session dependency
    app.dependency_overrides[get_engine] = get_test_engine  # Mock the engine dependency

    yield  # Run the tests.

    # Undo the overrides so the app object is left as it was found.
    app.dependency_overrides.pop(get_db, None)
    app.dependency_overrides.pop(get_engine, None)
    # Release every pooled connection to the test database before dropping it.
    _db_conn.dispose()
    drop_database(url)  # Drop the test database.


@pytest.fixture
def test_db_session():
    """
    Yield an SQLAlchemy session bound to the test database.

    After the test, the session is closed and all rows are deleted from every
    table in ``Base.metadata``, except the ``*_vocabulary`` tables. Those hold
    the controlled vocabulary and must survive between tests.
    """
    session = sessionmaker(bind=_db_conn, autoflush=False, autocommit=False)()
    try:
        yield session
    finally:
        # Close first. This ends (rolls back) anything the test left
        # uncommitted and puts the connection back into the pool, so it cannot
        # hold row locks that the DELETEs below would have to wait on.
        session.close()
        # Delete dependent tables before the tables they reference (reverse
        # dependency order), in one transaction committed on leaving the block.
        # Engine.begin() replaces Engine.execute(), which SQLAlchemy 2.0 removed.
        with _db_conn.begin() as conn:
            for table in reversed(Base.metadata.sorted_tables):
                # Keep the tables from "toar_controlled_vocabulary";
                # otherwise their content would be lost.
                if not table.name.endswith("_vocabulary"):
                    conn.execute(table.delete())


@pytest.fixture()
def client():
    """
    Yield a Starlette ``TestClient`` wrapping the FastAPI ``app``.

    The client is used as a context manager, so the application's startup
    and shutdown events run around the test. The client itself does NOT roll
    back or clean the database. Use the ``test_db_session`` fixture when a
    test needs its data removed afterwards.
    """
    with TestClient(app) as test_client:
        yield test_client