Files
main_backend/tests/integration/test_transcription_endpoints.py
2026-02-04 02:19:50 +03:00

296 lines
9.4 KiB
Python

"""
Tests for transcription endpoints.
"""
from __future__ import annotations
import uuid
from unittest.mock import AsyncMock, patch
import pytest
from httpx import AsyncClient
from sqlalchemy.ext.asyncio import AsyncSession
from cpv3.modules.files.models import File
from cpv3.modules.transcription.models import Transcription
from cpv3.modules.users.models import User
@pytest.fixture
async def source_file(test_db_session: AsyncSession, test_user: User) -> File:
    """Persist and return an MP3 ``File`` row owned by ``test_user``.

    The row is flagged as uploaded and active, added to ``test_db_session``,
    committed, and refreshed from the database before being handed to the test.
    """
    file_fields = {
        "id": uuid.uuid4(),
        "owner_id": test_user.id,
        "original_filename": "audio.mp3",
        "path": "uploads/audio.mp3",
        "storage_backend": "LOCAL",
        "mime_type": "audio/mpeg",
        "size_bytes": 5_000_000,
        "is_uploaded": True,
        "is_active": True,
    }
    audio_file = File(**file_fields)

    test_db_session.add(audio_file)
    await test_db_session.commit()
    # Reload the instance after the commit so the returned object reflects the stored row.
    await test_db_session.refresh(audio_file)
    return audio_file
@pytest.fixture
async def test_transcription(
    test_db_session: AsyncSession, source_file: File
) -> Transcription:
    """Persist and return a ``Transcription`` row linked to ``source_file``.

    The row uses the ``LOCAL_WHISPER`` engine, language ``en``, and a document
    with an empty ``segments`` list. It is committed and refreshed through
    ``test_db_session`` before being returned.
    """
    transcription_fields = {
        "id": uuid.uuid4(),
        "source_file_id": source_file.id,
        "engine": "LOCAL_WHISPER",
        "language": "en",
        "document": {"segments": []},
        "is_active": True,
    }
    record = Transcription(**transcription_fields)

    test_db_session.add(record)
    await test_db_session.commit()
    # Reload the instance after the commit so the returned object reflects the stored row.
    await test_db_session.refresh(record)
    return record
class TestListTranscriptionsEndpoint:
    """Exercise GET /api/transcribe/transcriptions/."""

    async def test_list_transcriptions(
        self, auth_client: AsyncClient, test_transcription: Transcription
    ):
        """Listing through ``auth_client`` answers 200 with a JSON list body."""
        endpoint = "/api/transcribe/transcriptions/"

        listing = await auth_client.get(endpoint)

        assert listing.status_code == 200
        assert isinstance(listing.json(), list)

    async def test_list_transcriptions_unauthenticated(self, async_client: AsyncClient):
        """Listing without auth is rejected with 401."""
        endpoint = "/api/transcribe/transcriptions/"

        anonymous_listing = await async_client.get(endpoint)

        assert anonymous_listing.status_code == 401
class TestCreateTranscriptionEndpoint:
    """Exercise POST /api/transcribe/transcriptions/."""

    async def test_create_transcription_success(
        self, auth_client: AsyncClient, source_file: File
    ):
        """Posting a one-segment document answers 201 and echoes engine and language."""
        segment = {
            "text": "Hello world",
            "semantic_tags": [],
            "structure_tags": [],
            "time": {"start": 0.0, "end": 2.0},
            "lines": [],
        }
        payload = {
            "source_file_id": str(source_file.id),
            "engine": "LOCAL_WHISPER",
            "language": "en",
            "document": {"segments": [segment]},
        }

        created = await auth_client.post(
            "/api/transcribe/transcriptions/", json=payload
        )

        assert created.status_code == 201
        body = created.json()
        # The response must carry back the same engine and language that were submitted.
        for field in ("engine", "language"):
            assert body[field] == payload[field]

    async def test_create_transcription_unauthenticated(
        self, async_client: AsyncClient
    ):
        """Posting without auth is rejected with 401."""
        payload = {
            "source_file_id": str(uuid.uuid4()),
            "document": {"segments": []},
        }

        rejected = await async_client.post(
            "/api/transcribe/transcriptions/", json=payload
        )

        assert rejected.status_code == 401
class TestRetrieveTranscriptionEndpoint:
    """Exercise GET /api/transcribe/transcriptions/{transcription_id}/."""

    async def test_retrieve_transcription(
        self, auth_client: AsyncClient, test_transcription: Transcription
    ):
        """Fetching the fixture row answers 200 with a matching id and engine."""
        detail_url = f"/api/transcribe/transcriptions/{test_transcription.id}/"

        fetched = await auth_client.get(detail_url)

        assert fetched.status_code == 200
        body = fetched.json()
        assert body["id"] == str(test_transcription.id)
        assert body["engine"] == test_transcription.engine

    async def test_retrieve_nonexistent_transcription(self, auth_client: AsyncClient):
        """Fetching a freshly generated, unused UUID answers 404."""
        missing_id = uuid.uuid4()
        detail_url = f"/api/transcribe/transcriptions/{missing_id}/"

        not_found = await auth_client.get(detail_url)

        assert not_found.status_code == 404
class TestPatchTranscriptionEndpoint:
    """Exercise PATCH /api/transcribe/transcriptions/{transcription_id}/."""

    async def test_patch_transcription(
        self, auth_client: AsyncClient, test_transcription: Transcription
    ):
        """Replacing the document answers 200 and returns the new segment text."""
        revised_segment = {
            "text": "Updated text",
            "semantic_tags": [],
            "structure_tags": [],
            "time": {"start": 0.0, "end": 3.0},
            "lines": [],
        }
        payload = {"document": {"segments": [revised_segment]}}
        detail_url = f"/api/transcribe/transcriptions/{test_transcription.id}/"

        patched = await auth_client.patch(detail_url, json=payload)

        assert patched.status_code == 200
        first_segment = patched.json()["document"]["segments"][0]
        assert first_segment["text"] == "Updated text"

    async def test_patch_nonexistent_transcription(self, auth_client: AsyncClient):
        """Patching a freshly generated, unused UUID answers 404."""
        missing_id = uuid.uuid4()
        detail_url = f"/api/transcribe/transcriptions/{missing_id}/"
        payload = {"document": {"segments": []}}

        not_found = await auth_client.patch(detail_url, json=payload)

        assert not_found.status_code == 404
class TestDeleteTranscriptionEndpoint:
    """Exercise DELETE /api/transcribe/transcriptions/{transcription_id}/."""

    async def test_delete_transcription(
        self, auth_client: AsyncClient, test_transcription: Transcription
    ):
        """Deleting the fixture row answers 204."""
        detail_url = f"/api/transcribe/transcriptions/{test_transcription.id}/"

        deleted = await auth_client.delete(detail_url)

        assert deleted.status_code == 204

    async def test_delete_nonexistent_transcription(self, auth_client: AsyncClient):
        """Deleting a freshly generated, unused UUID answers 404."""
        missing_id = uuid.uuid4()
        detail_url = f"/api/transcribe/transcriptions/{missing_id}/"

        not_found = await auth_client.delete(detail_url)

        assert not_found.status_code == 404
class TestWhisperTranscribeEndpoint:
    """Tests for POST /api/transcribe/whisper/."""

    async def test_whisper_transcribe_success(self, auth_client: AsyncClient):
        """Test that the Whisper endpoint answers 200 and uses the patched transcriber.

        ``transcribe_with_whisper`` is replaced in the router module by an
        ``AsyncMock`` returning a one-segment document, so no real model runs.
        Besides the status code, the test checks that the mock was awaited;
        without that check the test would still pass if the endpoint stopped
        delegating to the transcriber.
        """
        mock_result = {
            "segments": [
                {
                    "text": "Hello from Whisper",
                    "semantic_tags": [],
                    "structure_tags": [],
                    "time": {"start": 0.0, "end": 2.5},
                    "lines": [],
                }
            ]
        }
        with patch(
            "cpv3.modules.transcription.router.transcribe_with_whisper",
            new_callable=AsyncMock,
            return_value=mock_result,
        ) as whisper_mock:
            response = await auth_client.post(
                "/api/transcribe/whisper/",
                json={
                    "file_path": "uploads/audio.mp3",
                    "model_name": "tiny",
                    "language": "en",
                },
            )

        assert response.status_code == 200
        # NOTE(review): assumes the router awaits the transcriber exactly once
        # per request - confirm against the router implementation.
        whisper_mock.assert_awaited_once()

    async def test_whisper_transcribe_unauthenticated(self, async_client: AsyncClient):
        """Test Whisper transcription without auth returns 401."""
        response = await async_client.post(
            "/api/transcribe/whisper/",
            json={"file_path": "test.mp3"},
        )
        assert response.status_code == 401
class TestGoogleSpeechTranscribeEndpoint:
    """Tests for POST /api/transcribe/google-speech/."""

    async def test_google_speech_transcribe_success(self, auth_client: AsyncClient):
        """Test that the Google Speech endpoint answers 200 and uses the patched transcriber.

        ``transcribe_with_google_speech`` is replaced in the router module by an
        ``AsyncMock`` returning a one-segment document, so no external service
        is contacted. Besides the status code, the test checks that the mock was
        awaited; without that check the test would still pass if the endpoint
        stopped delegating to the transcriber.
        """
        mock_result = {
            "segments": [
                {
                    "text": "Hello from Google",
                    "semantic_tags": [],
                    "structure_tags": [],
                    "time": {"start": 0.0, "end": 2.0},
                    "lines": [],
                }
            ]
        }
        with patch(
            "cpv3.modules.transcription.router.transcribe_with_google_speech",
            new_callable=AsyncMock,
            return_value=mock_result,
        ) as google_speech_mock:
            response = await auth_client.post(
                "/api/transcribe/google-speech/",
                json={
                    "file_path": "uploads/audio.mp3",
                    "language_codes": ["en-US"],
                },
            )

        assert response.status_code == 200
        # NOTE(review): assumes the router awaits the transcriber exactly once
        # per request - confirm against the router implementation.
        google_speech_mock.assert_awaited_once()

    async def test_google_speech_transcribe_unauthenticated(
        self, async_client: AsyncClient
    ):
        """Test Google Speech transcription without auth returns 401."""
        response = await async_client.post(
            "/api/transcribe/google-speech/",
            json={"file_path": "test.mp3"},
        )
        assert response.status_code == 401