# taken from: https://github.com/tiangolo/fastapi/issues/831
"""Pytest fixtures that run the tests against a throw-away copy of the database."""
from typing import AsyncIterable, Iterator, Optional

import pytest
from starlette.testclient import TestClient
from sqlalchemy import create_engine
from sqlalchemy.engine import Engine as Database
from sqlalchemy.orm import Session
from sqlalchemy_utils import database_exists, create_database, drop_database

from toardb.utils.database import DATABASE_URL
from .models import Base
from toardb.auth_user.models import Base as AuthUserBase
from toardb.contacts.models import Base as ContactBase
from toardb.toardb import app
from toardb.stationmeta.stationmeta import get_db

# The test database lives next to the real one, with a "_test" suffix.
url = str(DATABASE_URL + "_test")
_db_conn = create_engine(url)

# Id sequences that must exist in the public schema before the tables are
# created with create_all().
_ID_SEQUENCES = (
    "auth_user_id_seq",
    "stationmeta_core_id_seq",
    "stationmeta_global_id_seq",
    "stationmeta_global_services_id_seq",
    "stationmeta_annotations_id_seq",
    "stationmeta_roles_id_seq",
    "stationmeta_aux_docs_id_seq",
    "stationmeta_aux_images_id_seq",
    "stationmeta_aux_urls_id_seq",
    "persons_id_seq",
    "organisations_id_seq",
)


def get_test_db_conn() -> Database:
    """Return the engine bound to the test database."""
    assert _db_conn is not None
    return _db_conn


def get_test_db() -> Iterator[Session]:
    """Yield a session on the test database and close it afterwards.

    Installed as the override for the app's ``get_db`` dependency.
    """
    sess = Session(bind=_db_conn)
    try:
        yield sess
    finally:
        sess.close()


def _install_db_prerequisites(engine: Database) -> None:
    """Install PostGIS and the id sequences the models need before create_all().

    'create_all' does not work without the PostGIS extension, because the
    'Geometry' column type is otherwise undefined.
    """
    raw_conn = engine.raw_connection()
    try:
        cur = raw_conn.cursor()
        try:
            cur.execute("CREATE EXTENSION postgis")
            raw_conn.commit()
            for seq_name in _ID_SEQUENCES:
                cur.execute(
                    f"CREATE SEQUENCE public.{seq_name} AS integer START WITH 1 "
                    "INCREMENT BY 1 NO MINVALUE NO MAXVALUE CACHE 1"
                )
                raw_conn.commit()
        finally:
            cur.close()
    finally:
        # Hand the DBAPI connection back to the pool instead of leaking it.
        raw_conn.close()


@pytest.fixture(scope="session", autouse=True)
def create_test_database():
    """Create a clean test database once per test session and drop it afterwards.

    The `sqlalchemy_utils` helpers are used to consistently check for,
    create and drop the database. While the session runs, the app's
    ``get_db`` dependency is redirected to the test database.
    """
    if database_exists(url):
        drop_database(url)
    create_database(url)  # Create the test database.
    _install_db_prerequisites(_db_conn)
    # Create the tables, in the same metadata order as before.
    AuthUserBase.metadata.create_all(_db_conn)
    ContactBase.metadata.create_all(_db_conn)
    Base.metadata.create_all(_db_conn)
    app.dependency_overrides[get_db] = get_test_db  # Mock the Database Dependency
    yield  # Run the tests.
    app.dependency_overrides.pop(get_db, None)
    # Close idle pooled connections so they do not keep the database in use.
    _db_conn.dispose()
    drop_database(url)  # Drop the test database.


@pytest.fixture
def test_db_session():
    """Yield a SQLAlchemy session and delete all table rows after the test."""
    session = Session(bind=_db_conn)
    yield session
    # Put the connection back into the pool first, so any transaction still
    # open on this session cannot block the deletes below.
    session.close()
    # Delete all data after each test. Tables are visited in the reverse of
    # the concatenated sorted_tables order, so dependent tables go first.
    all_tables = (
        ContactBase.metadata.sorted_tables
        + AuthUserBase.metadata.sorted_tables
        + Base.metadata.sorted_tables
    )
    with _db_conn.begin() as conn:
        for tbl in reversed(all_tables):
            conn.execute(tbl.delete())


@pytest.fixture()
def client():
    """Yield a TestClient for the app.

    The client is opened as a context manager, so the app's startup and
    shutdown handlers run around the test. This fixture does NOT roll back
    database changes; use ``test_db_session`` for per-test cleanup.
    """
    with TestClient(app) as test_client:
        yield test_client