""" Tests for transcription endpoints. """ from __future__ import annotations import uuid from unittest.mock import AsyncMock, patch import pytest from httpx import AsyncClient from sqlalchemy.ext.asyncio import AsyncSession from cpv3.modules.files.models import File from cpv3.modules.transcription.models import Transcription from cpv3.modules.users.models import User @pytest.fixture async def source_file(test_db_session: AsyncSession, test_user: User) -> File: """Create a source file for transcription.""" file = File( id=uuid.uuid4(), owner_id=test_user.id, original_filename="audio.mp3", path="uploads/audio.mp3", storage_backend="LOCAL", mime_type="audio/mpeg", size_bytes=5000000, is_uploaded=True, is_active=True, ) test_db_session.add(file) await test_db_session.commit() await test_db_session.refresh(file) return file @pytest.fixture async def test_transcription( test_db_session: AsyncSession, source_file: File ) -> Transcription: """Create a test transcription.""" transcription = Transcription( id=uuid.uuid4(), source_file_id=source_file.id, engine="LOCAL_WHISPER", language="en", document={"segments": []}, is_active=True, ) test_db_session.add(transcription) await test_db_session.commit() await test_db_session.refresh(transcription) return transcription class TestListTranscriptionsEndpoint: """Tests for GET /api/transcribe/transcriptions/.""" async def test_list_transcriptions( self, auth_client: AsyncClient, test_transcription: Transcription ): """Test listing transcriptions.""" response = await auth_client.get("/api/transcribe/transcriptions/") assert response.status_code == 200 data = response.json() assert isinstance(data, list) async def test_list_transcriptions_unauthenticated(self, async_client: AsyncClient): """Test listing transcriptions without auth returns 401.""" response = await async_client.get("/api/transcribe/transcriptions/") assert response.status_code == 401 class TestCreateTranscriptionEndpoint: """Tests for POST /api/transcribe/transcriptions/.""" async def test_create_transcription_success( self, auth_client: AsyncClient, source_file: File ): """Test creating a transcription entry.""" response = await auth_client.post( "/api/transcribe/transcriptions/", json={ "source_file_id": str(source_file.id), "engine": "LOCAL_WHISPER", "language": "en", "document": { "segments": [ { "text": "Hello world", "semantic_tags": [], "structure_tags": [], "time": {"start": 0.0, "end": 2.0}, "lines": [], } ] }, }, ) assert response.status_code == 201 data = response.json() assert data["engine"] == "LOCAL_WHISPER" assert data["language"] == "en" async def test_create_transcription_unauthenticated( self, async_client: AsyncClient ): """Test creating transcription without auth returns 401.""" response = await async_client.post( "/api/transcribe/transcriptions/", json={ "source_file_id": str(uuid.uuid4()), "document": {"segments": []}, }, ) assert response.status_code == 401 class TestRetrieveTranscriptionEndpoint: """Tests for GET /api/transcribe/transcriptions/{transcription_id}/.""" async def test_retrieve_transcription( self, auth_client: AsyncClient, test_transcription: Transcription ): """Test retrieving a transcription.""" response = await auth_client.get( f"/api/transcribe/transcriptions/{test_transcription.id}/" ) assert response.status_code == 200 data = response.json() assert data["id"] == str(test_transcription.id) assert data["engine"] == test_transcription.engine async def test_retrieve_nonexistent_transcription(self, auth_client: AsyncClient): """Test retrieving nonexistent transcription returns 404.""" fake_id = uuid.uuid4() response = await auth_client.get(f"/api/transcribe/transcriptions/{fake_id}/") assert response.status_code == 404 class TestPatchTranscriptionEndpoint: """Tests for PATCH /api/transcribe/transcriptions/{transcription_id}/.""" async def test_patch_transcription( self, auth_client: AsyncClient, test_transcription: Transcription ): """Test updating a transcription.""" updated_document = { "segments": [ { "text": "Updated text", "semantic_tags": [], "structure_tags": [], "time": {"start": 0.0, "end": 3.0}, "lines": [], } ] } response = await auth_client.patch( f"/api/transcribe/transcriptions/{test_transcription.id}/", json={"document": updated_document}, ) assert response.status_code == 200 data = response.json() assert data["document"]["segments"][0]["text"] == "Updated text" async def test_patch_nonexistent_transcription(self, auth_client: AsyncClient): """Test patching nonexistent transcription returns 404.""" fake_id = uuid.uuid4() response = await auth_client.patch( f"/api/transcribe/transcriptions/{fake_id}/", json={"document": {"segments": []}}, ) assert response.status_code == 404 class TestDeleteTranscriptionEndpoint: """Tests for DELETE /api/transcribe/transcriptions/{transcription_id}/.""" async def test_delete_transcription( self, auth_client: AsyncClient, test_transcription: Transcription ): """Test deleting a transcription.""" response = await auth_client.delete( f"/api/transcribe/transcriptions/{test_transcription.id}/" ) assert response.status_code == 204 async def test_delete_nonexistent_transcription(self, auth_client: AsyncClient): """Test deleting nonexistent transcription returns 404.""" fake_id = uuid.uuid4() response = await auth_client.delete( f"/api/transcribe/transcriptions/{fake_id}/" ) assert response.status_code == 404 class TestWhisperTranscribeEndpoint: """Tests for POST /api/transcribe/whisper/.""" async def test_whisper_transcribe_success(self, auth_client: AsyncClient): """Test Whisper transcription endpoint.""" mock_result = { "segments": [ { "text": "Hello from Whisper", "semantic_tags": [], "structure_tags": [], "time": {"start": 0.0, "end": 2.5}, "lines": [], } ] } with patch( "cpv3.modules.transcription.router.transcribe_with_whisper", new_callable=AsyncMock, return_value=mock_result, ): response = await auth_client.post( "/api/transcribe/whisper/", json={ "file_path": "uploads/audio.mp3", "model_name": "tiny", "language": "en", }, ) assert response.status_code == 200 async def test_whisper_transcribe_unauthenticated(self, async_client: AsyncClient): """Test Whisper transcription without auth returns 401.""" response = await async_client.post( "/api/transcribe/whisper/", json={"file_path": "test.mp3"}, ) assert response.status_code == 401 class TestGoogleSpeechTranscribeEndpoint: """Tests for POST /api/transcribe/google-speech/.""" async def test_google_speech_transcribe_success(self, auth_client: AsyncClient): """Test Google Speech transcription endpoint.""" mock_result = { "segments": [ { "text": "Hello from Google", "semantic_tags": [], "structure_tags": [], "time": {"start": 0.0, "end": 2.0}, "lines": [], } ] } with patch( "cpv3.modules.transcription.router.transcribe_with_google_speech", new_callable=AsyncMock, return_value=mock_result, ): response = await auth_client.post( "/api/transcribe/google-speech/", json={ "file_path": "uploads/audio.mp3", "language_codes": ["en-US"], }, ) assert response.status_code == 200 async def test_google_speech_transcribe_unauthenticated( self, async_client: AsyncClient ): """Test Google Speech transcription without auth returns 401.""" response = await async_client.post( "/api/transcribe/google-speech/", json={"file_path": "test.mp3"}, ) assert response.status_code == 401