feature: create multitasking
This commit is contained in:
@@ -0,0 +1,295 @@
|
||||
"""
|
||||
Tests for transcription endpoints.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
from httpx import AsyncClient
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from cpv3.modules.files.models import File
|
||||
from cpv3.modules.transcription.models import Transcription
|
||||
from cpv3.modules.users.models import User
|
||||
|
||||
|
||||
@pytest.fixture
async def source_file(test_db_session: AsyncSession, test_user: User) -> File:
    """Persist an MP3 ``File`` row owned by ``test_user`` and return it.

    The row is added to ``test_db_session``, committed, and refreshed
    before being handed to the requesting test or fixture.
    """
    file_attributes = {
        "id": uuid.uuid4(),
        "owner_id": test_user.id,
        "original_filename": "audio.mp3",
        "path": "uploads/audio.mp3",
        "storage_backend": "LOCAL",
        "mime_type": "audio/mpeg",
        "size_bytes": 5000000,
        "is_uploaded": True,
        "is_active": True,
    }
    audio_file = File(**file_attributes)

    test_db_session.add(audio_file)
    await test_db_session.commit()
    # Reload the instance from the session after the commit.
    await test_db_session.refresh(audio_file)
    return audio_file
|
||||
|
||||
|
||||
@pytest.fixture
async def test_transcription(
    test_db_session: AsyncSession, source_file: File
) -> Transcription:
    """Persist a ``Transcription`` linked to ``source_file`` and return it.

    The record uses the ``LOCAL_WHISPER`` engine, language ``en`` and a
    document with an empty ``segments`` list; it is committed and refreshed
    through ``test_db_session`` before being returned.
    """
    empty_document = {"segments": []}
    record = Transcription(
        id=uuid.uuid4(),
        source_file_id=source_file.id,
        engine="LOCAL_WHISPER",
        language="en",
        document=empty_document,
        is_active=True,
    )

    test_db_session.add(record)
    await test_db_session.commit()
    await test_db_session.refresh(record)
    return record
|
||||
|
||||
|
||||
class TestListTranscriptionsEndpoint:
    """Tests for GET /api/transcribe/transcriptions/."""

    async def test_list_transcriptions(
        self, auth_client: AsyncClient, test_transcription: Transcription
    ):
        """Expect 200 and a JSON list when listing via ``auth_client``.

        ``test_transcription`` is requested only so that a row exists;
        the test does not inspect the list contents.
        """
        listing = await auth_client.get("/api/transcribe/transcriptions/")

        assert listing.status_code == 200
        payload = listing.json()
        assert isinstance(payload, list)

    async def test_list_transcriptions_unauthenticated(self, async_client: AsyncClient):
        """Expect 401 when the list endpoint is called via ``async_client``."""
        listing = await async_client.get("/api/transcribe/transcriptions/")

        assert listing.status_code == 401
|
||||
|
||||
|
||||
class TestCreateTranscriptionEndpoint:
    """Tests for POST /api/transcribe/transcriptions/."""

    async def test_create_transcription_success(
        self, auth_client: AsyncClient, source_file: File
    ):
        """POST a one-segment document for ``source_file`` and expect 201.

        The response body must echo the engine and language that were sent.
        """
        segment = {
            "text": "Hello world",
            "semantic_tags": [],
            "structure_tags": [],
            "time": {"start": 0.0, "end": 2.0},
            "lines": [],
        }
        payload = {
            "source_file_id": str(source_file.id),
            "engine": "LOCAL_WHISPER",
            "language": "en",
            "document": {"segments": [segment]},
        }

        created = await auth_client.post(
            "/api/transcribe/transcriptions/", json=payload
        )

        assert created.status_code == 201
        body = created.json()
        assert body["engine"] == "LOCAL_WHISPER"
        assert body["language"] == "en"

    async def test_create_transcription_unauthenticated(
        self, async_client: AsyncClient
    ):
        """Expect 401 when the create endpoint is called via ``async_client``."""
        payload = {
            "source_file_id": str(uuid.uuid4()),
            "document": {"segments": []},
        }

        rejected = await async_client.post(
            "/api/transcribe/transcriptions/", json=payload
        )

        assert rejected.status_code == 401
|
||||
|
||||
|
||||
class TestRetrieveTranscriptionEndpoint:
    """Tests for GET /api/transcribe/transcriptions/{transcription_id}/."""

    async def test_retrieve_transcription(
        self, auth_client: AsyncClient, test_transcription: Transcription
    ):
        """Fetch the fixture transcription by id and check id and engine."""
        detail_url = f"/api/transcribe/transcriptions/{test_transcription.id}/"

        fetched = await auth_client.get(detail_url)

        assert fetched.status_code == 200
        body = fetched.json()
        assert body["id"] == str(test_transcription.id)
        assert body["engine"] == test_transcription.engine

    async def test_retrieve_nonexistent_transcription(self, auth_client: AsyncClient):
        """Expect 404 when fetching a freshly generated, unsaved UUID."""
        missing_id = uuid.uuid4()

        fetched = await auth_client.get(f"/api/transcribe/transcriptions/{missing_id}/")

        assert fetched.status_code == 404
|
||||
|
||||
|
||||
class TestPatchTranscriptionEndpoint:
    """Tests for PATCH /api/transcribe/transcriptions/{transcription_id}/."""

    async def test_patch_transcription(
        self, auth_client: AsyncClient, test_transcription: Transcription
    ):
        """PATCH a new one-segment document and check the returned text."""
        replacement_segment = {
            "text": "Updated text",
            "semantic_tags": [],
            "structure_tags": [],
            "time": {"start": 0.0, "end": 3.0},
            "lines": [],
        }
        new_document = {"segments": [replacement_segment]}
        detail_url = f"/api/transcribe/transcriptions/{test_transcription.id}/"

        patched = await auth_client.patch(detail_url, json={"document": new_document})

        assert patched.status_code == 200
        body = patched.json()
        assert body["document"]["segments"][0]["text"] == "Updated text"

    async def test_patch_nonexistent_transcription(self, auth_client: AsyncClient):
        """Expect 404 when patching a freshly generated, unsaved UUID."""
        missing_id = uuid.uuid4()
        body = {"document": {"segments": []}}

        patched = await auth_client.patch(
            f"/api/transcribe/transcriptions/{missing_id}/", json=body
        )

        assert patched.status_code == 404
|
||||
|
||||
|
||||
class TestDeleteTranscriptionEndpoint:
    """Tests for DELETE /api/transcribe/transcriptions/{transcription_id}/."""

    async def test_delete_transcription(
        self, auth_client: AsyncClient, test_transcription: Transcription
    ):
        """Delete the fixture transcription by id and expect 204."""
        detail_url = f"/api/transcribe/transcriptions/{test_transcription.id}/"

        deleted = await auth_client.delete(detail_url)

        assert deleted.status_code == 204

    async def test_delete_nonexistent_transcription(self, auth_client: AsyncClient):
        """Expect 404 when deleting a freshly generated, unsaved UUID."""
        missing_id = uuid.uuid4()

        deleted = await auth_client.delete(
            f"/api/transcribe/transcriptions/{missing_id}/"
        )

        assert deleted.status_code == 404
|
||||
|
||||
|
||||
class TestWhisperTranscribeEndpoint:
    """Tests for POST /api/transcribe/whisper/."""

    async def test_whisper_transcribe_success(self, auth_client: AsyncClient):
        """Expect 200 with the router's Whisper helper patched out.

        ``transcribe_with_whisper`` in the router module is replaced by an
        ``AsyncMock`` returning a canned one-segment document for the
        duration of the request.
        """
        canned_segment = {
            "text": "Hello from Whisper",
            "semantic_tags": [],
            "structure_tags": [],
            "time": {"start": 0.0, "end": 2.5},
            "lines": [],
        }
        canned_document = {"segments": [canned_segment]}
        request_body = {
            "file_path": "uploads/audio.mp3",
            "model_name": "tiny",
            "language": "en",
        }
        patch_target = "cpv3.modules.transcription.router.transcribe_with_whisper"

        with patch(patch_target, new_callable=AsyncMock, return_value=canned_document):
            result = await auth_client.post("/api/transcribe/whisper/", json=request_body)

        assert result.status_code == 200

    async def test_whisper_transcribe_unauthenticated(self, async_client: AsyncClient):
        """Expect 401 when the Whisper endpoint is called via ``async_client``."""
        request_body = {"file_path": "test.mp3"}

        result = await async_client.post("/api/transcribe/whisper/", json=request_body)

        assert result.status_code == 401
|
||||
|
||||
|
||||
class TestGoogleSpeechTranscribeEndpoint:
    """Tests for POST /api/transcribe/google-speech/."""

    async def test_google_speech_transcribe_success(self, auth_client: AsyncClient):
        """Expect 200 with the router's Google Speech helper patched out.

        ``transcribe_with_google_speech`` in the router module is replaced by
        an ``AsyncMock`` returning a canned one-segment document for the
        duration of the request.
        """
        canned_segment = {
            "text": "Hello from Google",
            "semantic_tags": [],
            "structure_tags": [],
            "time": {"start": 0.0, "end": 2.0},
            "lines": [],
        }
        canned_document = {"segments": [canned_segment]}
        request_body = {
            "file_path": "uploads/audio.mp3",
            "language_codes": ["en-US"],
        }
        patch_target = (
            "cpv3.modules.transcription.router.transcribe_with_google_speech"
        )

        with patch(patch_target, new_callable=AsyncMock, return_value=canned_document):
            result = await auth_client.post(
                "/api/transcribe/google-speech/", json=request_body
            )

        assert result.status_code == 200

    async def test_google_speech_transcribe_unauthenticated(
        self, async_client: AsyncClient
    ):
        """Expect 401 when the Google Speech endpoint is called via ``async_client``."""
        request_body = {"file_path": "test.mp3"}

        result = await async_client.post(
            "/api/transcribe/google-speech/", json=request_body
        )

        assert result.status_code == 401
|
||||
Reference in New Issue
Block a user