import warnings
from collections.abc import AsyncGenerator

import pytest
from aiopg.sa import Engine, SAConnection, create_engine
from fastapi import FastAPI
from httpx import ASGITransport, AsyncClient
from sqlalchemy import Enum, Table
from sqlalchemy.dialects import postgresql
from sqlalchemy.sql.ddl import CreateTable

from src.api_app import create_app
from src.settings import WebAppSettings


@pytest.fixture
def test_settings() -> WebAppSettings:
    """Return application settings pointing at the dedicated test database."""
    return WebAppSettings(
        postgres_dsn="postgresql://messenger:messenger@localhost:5432/messenger_test",
        port=8000,
    )


@pytest.fixture
def test_app(test_settings: WebAppSettings) -> FastAPI:
    """Build the FastAPI application configured with ``test_settings``."""
    return create_app(settings=test_settings)


@pytest.fixture
async def test_client(test_app: FastAPI) -> AsyncGenerator[AsyncClient, None]:
    """Yield an HTTP client that sends requests in-process to ``test_app``.

    Requests go through ``ASGITransport`` straight into the ASGI app; no
    network server is started. The client's base URL is
    ``http://test/api/v1``.
    """
    # httpx deprecated ``AsyncClient(app=...)`` in 0.27 and removed it in
    # 0.28; passing an explicit ASGITransport is the supported equivalent.
    transport = ASGITransport(app=test_app)
    async with AsyncClient(
        transport=transport,
        base_url="http://test/api/v1",
    ) as client:
        yield client


@pytest.fixture
def sa_tables() -> list[Table]:
    """Tables that ``db_engine`` creates before each test.

    Override this fixture locally (in a test module or a nested conftest) to
    choose which tables are created. The default creates none and emits a
    warning: ``db_engine`` is autouse, so it requests this fixture for every
    test that does not override it.
    """
    warnings.warn(
        "Please, override `sa_tables` fixture if you are using `db_engine`.",
        stacklevel=2,
    )
    return []


@pytest.fixture
def sa_enums() -> list[Enum]:
    """Enum types that ``db_engine`` creates before each test.

    Override this fixture locally to choose which enum types are created.
    They are created before the tables, so table columns may use them.
    The default creates none.
    """
    return []


@pytest.fixture(autouse=True)
async def db_engine(
    test_settings: WebAppSettings,
    sa_tables: list[Table],
    sa_enums: list[Enum],
) -> AsyncGenerator[Engine, None]:
    """Yield an aiopg engine bound to a freshly reset test schema.

    Runs automatically for every test (function scope). Before yielding, it
    drops and recreates the ``public`` schema, then creates ``sa_enums``
    followed by ``sa_tables``. Nothing is removed after the test; the next
    test's setup performs the reset instead. The engine is released by the
    surrounding ``async with`` once the test finishes.
    """
    async with create_engine(test_settings.postgres_dsn) as engine:
        # A short-lived connection runs the DDL and goes back to the pool
        # before the test body starts.
        async with engine.acquire() as setup_conn:
            await drop_tables(setup_conn)
            await create_enums(setup_conn, sa_enums)
            await create_tables(setup_conn, sa_tables)
        yield engine


@pytest.fixture
async def connection(db_engine: Engine) -> AsyncGenerator[SAConnection, None]:
    """Yield a connection acquired from ``db_engine``; released after the test."""
    async with db_engine.acquire() as conn:
        yield conn


async def drop_tables(connection: SAConnection) -> None:
    """Remove every object in the ``public`` schema by recreating the schema.

    Despite the name, this drops the whole ``public`` schema with CASCADE
    (so enum types and any other objects in it go too), then recreates it
    empty.
    """
    await connection.execute("DROP SCHEMA public CASCADE;")
    await connection.execute("CREATE SCHEMA public;")


async def create_tables(connection: SAConnection, tables: list[Table]) -> None:
    """Issue ``CREATE TABLE`` for each table, in the order given.

    Tables are created one at a time, so a table with a foreign key must be
    listed after the table it references.
    """
    # The dialect is loop-invariant; build it once.
    dialect = postgresql.dialect()
    for table in tables:
        ddl = str(CreateTable(table).compile(dialect=dialect))
        await connection.execute(ddl)


async def create_enums(connection: SAConnection, enums: list[Enum]) -> None:
    """Issue ``CREATE TYPE ... AS ENUM`` for each enum, in the order given.

    Each enum presumably needs a ``name`` for the DDL to compile; verify
    against the enums passed via ``sa_enums``.
    """
    # The dialect is loop-invariant; build it once.
    dialect = postgresql.dialect()
    for enum in enums:
        ddl = str(postgresql.CreateEnumType(enum).compile(dialect=dialect))
        await connection.execute(ddl)