Files
2026-04-27 23:19:04 +03:00

251 lines
7.0 KiB
Python

"""
Shared test fixtures and configuration.
"""
from __future__ import annotations

import uuid
from contextlib import asynccontextmanager
from datetime import timedelta
from pathlib import Path
from tempfile import NamedTemporaryFile
from typing import AsyncGenerator
from unittest.mock import AsyncMock, MagicMock

import pytest
from httpx import ASGITransport, AsyncClient
from sqlalchemy import create_engine
from sqlalchemy.ext.asyncio import (
    AsyncEngine,
    AsyncSession,
    async_sessionmaker,
    create_async_engine,
)

from cpv3.db.base import Base
from cpv3.db.session import get_db
from cpv3.infrastructure.auth import get_current_user
from cpv3.infrastructure.deps import get_storage
from cpv3.infrastructure.security import create_token, hash_password
from cpv3.main import app
from cpv3.modules.users.models import User
@pytest.fixture
async def test_engine() -> AsyncGenerator[AsyncEngine, None]:
    """Yield an async engine bound to a fresh SQLite file with all tables created.

    A real on-disk file is used because two engines must see the same database.
    A synchronous engine runs ``Base.metadata.create_all``, and an
    ``aiosqlite`` engine is handed to the test. Every resource acquired here is
    released in ``finally``, so the temporary file is removed even if schema
    creation or engine construction fails.
    """
    # delete=False: the file must survive closing this handle so both engines
    # can open it by path; it is removed explicitly in the ``finally`` below.
    with NamedTemporaryFile(suffix=".sqlite3", delete=False) as tmp_db:
        db_path = Path(tmp_db.name)

    sync_engine = None
    engine = None
    try:
        sync_engine = create_engine(f"sqlite:///{db_path}", echo=False)
        Base.metadata.create_all(bind=sync_engine)
        engine = create_async_engine(f"sqlite+aiosqlite:///{db_path}", echo=False)
        yield engine
    finally:
        # Close pooled connections before deleting the file they point at.
        # No drop_all: the whole database file is deleted right after.
        if engine is not None:
            await engine.dispose()
        if sync_engine is not None:
            sync_engine.dispose()
        db_path.unlink(missing_ok=True)
@pytest.fixture
async def test_db_session(test_engine: AsyncEngine) -> AsyncGenerator[AsyncSession, None]:
    """Yield an ``AsyncSession`` bound to this test's private database.

    Isolation between tests does not come from a wrapping transaction; none is
    opened or rolled back here. It comes from ``test_engine``, which creates a
    brand-new SQLite file for every test. Anything committed through this
    session stays committed until that file is deleted at teardown.

    ``expire_on_commit=False`` keeps attribute values on committed objects
    loaded, so fixtures can return ORM instances that remain readable after
    ``commit()``. The session is closed when the ``async with`` exits.
    """
    session_factory = async_sessionmaker(
        bind=test_engine, class_=AsyncSession, expire_on_commit=False
    )
    async with session_factory() as session:
        yield session
async def _create_user(
    session: AsyncSession,
    *,
    username: str,
    email: str,
    password: str,
    first_name: str,
    is_staff: bool = False,
) -> User:
    """Persist an active, non-superuser ``User`` and return it refreshed from the DB.

    The plain-text *password* is hashed with ``hash_password``. ``last_name`` is
    always ``"User"``. The row is committed, then reloaded via ``refresh`` so
    the returned instance reflects what the database stored.
    """
    user = User(
        id=uuid.uuid4(),
        username=username,
        email=email,
        password_hash=hash_password(password),
        first_name=first_name,
        last_name="User",
        is_active=True,
        is_staff=is_staff,
        is_superuser=False,
    )
    session.add(user)
    await session.commit()
    await session.refresh(user)
    return user


@pytest.fixture
async def test_user(test_db_session: AsyncSession) -> User:
    """Create a regular test user (password ``testpassword``)."""
    return await _create_user(
        test_db_session,
        username="testuser",
        email="test@example.com",
        password="testpassword",
        first_name="Test",
    )


@pytest.fixture
async def staff_user(test_db_session: AsyncSession) -> User:
    """Create a staff test user (password ``staffpassword``)."""
    return await _create_user(
        test_db_session,
        username="staffuser",
        email="staff@example.com",
        password="staffpassword",
        first_name="Staff",
        is_staff=True,
    )


@pytest.fixture
async def other_user(test_db_session: AsyncSession) -> User:
    """Create another regular user for permission testing (password ``otherpassword``)."""
    return await _create_user(
        test_db_session,
        username="otheruser",
        email="other@example.com",
        password="otherpassword",
        first_name="Other",
    )
# Lifetime of the access tokens minted for test requests.
_TEST_ACCESS_TOKEN_TTL = timedelta(hours=1)


def _bearer_headers(user: User) -> dict[str, str]:
    """Return an ``Authorization: Bearer`` header with a fresh access JWT for *user*.

    The token subject is ``str(user.id)`` and it expires after
    ``_TEST_ACCESS_TOKEN_TTL``.
    """
    token = create_token(
        subject=str(user.id),
        token_type="access",
        expires_in=_TEST_ACCESS_TOKEN_TTL,
    )
    return {"Authorization": f"Bearer {token}"}


@pytest.fixture
def auth_headers(test_user: User) -> dict[str, str]:
    """Generate auth headers with valid JWT for the test user."""
    return _bearer_headers(test_user)


@pytest.fixture
def staff_auth_headers(staff_user: User) -> dict[str, str]:
    """Generate auth headers with valid JWT for the staff user."""
    return _bearer_headers(staff_user)


@pytest.fixture
def other_auth_headers(other_user: User) -> dict[str, str]:
    """Generate auth headers with valid JWT for the other user."""
    return _bearer_headers(other_user)
@pytest.fixture
def mock_storage() -> MagicMock:
    """Return a stand-in for the storage dependency.

    Three async methods are stubbed with canned results:

    * ``upload_fileobj`` returns ``"uploads/test-file.txt"``;
    * ``exists`` returns ``True``;
    * ``get_file_info`` returns an object whose ``file_path``, ``file_url``,
      ``file_size`` and ``filename`` describe a 1024-byte ``test-file.txt``.

    Any other attribute access yields an auto-created ``MagicMock``.
    """
    stored_path = "uploads/test-file.txt"
    info = MagicMock(
        file_path=stored_path,
        file_url="http://example.com/uploads/test-file.txt",
        file_size=1024,
        filename="test-file.txt",
    )

    fake = MagicMock()
    fake.upload_fileobj = AsyncMock(return_value=stored_path)
    fake.exists = AsyncMock(return_value=True)
    fake.get_file_info = AsyncMock(return_value=info)
    return fake
@asynccontextmanager
async def _client_with_overrides(
    session: AsyncSession,
    storage: MagicMock,
    user: User | None = None,
) -> AsyncGenerator[AsyncClient, None]:
    """Yield an ``AsyncClient`` wired to ``app`` with test dependency overrides.

    ``get_db`` yields *session* and ``get_storage`` returns *storage*. When
    *user* is given, ``get_current_user`` is overridden to return it, replacing
    whatever that dependency normally does. Otherwise the real dependency runs.

    The prior contents of ``app.dependency_overrides`` are restored in
    ``finally``. Overrides therefore never leak into later tests, even if
    entering or exiting the client raises. Overrides installed by anything
    outside this helper are preserved rather than cleared.
    """

    async def override_get_db():
        yield session

    async def override_get_storage():
        return storage

    overrides = {get_db: override_get_db, get_storage: override_get_storage}
    if user is not None:

        def override_get_current_user():
            return user

        overrides[get_current_user] = override_get_current_user

    saved = dict(app.dependency_overrides)
    app.dependency_overrides.update(overrides)
    try:
        async with AsyncClient(
            transport=ASGITransport(app=app),
            base_url="http://test",
        ) as client:
            yield client
    finally:
        app.dependency_overrides.clear()
        app.dependency_overrides.update(saved)


@pytest.fixture
async def async_client(
    test_db_session: AsyncSession,
    mock_storage: MagicMock,
) -> AsyncGenerator[AsyncClient, None]:
    """Create async test client with dependency overrides (no auth override)."""
    async with _client_with_overrides(test_db_session, mock_storage) as client:
        yield client


@pytest.fixture
async def auth_client(
    test_db_session: AsyncSession,
    test_user: User,
    mock_storage: MagicMock,
) -> AsyncGenerator[AsyncClient, None]:
    """Create async test client with auth dependency overridden to return test_user."""
    async with _client_with_overrides(
        test_db_session, mock_storage, user=test_user
    ) as client:
        yield client


@pytest.fixture
async def staff_client(
    test_db_session: AsyncSession,
    staff_user: User,
    mock_storage: MagicMock,
) -> AsyncGenerator[AsyncClient, None]:
    """Create async test client with auth dependency overridden to return staff_user."""
    async with _client_with_overrides(
        test_db_session, mock_storage, user=staff_user
    ) as client:
        yield client