75 lines
2.5 KiB
Python
75 lines
2.5 KiB
Python
from __future__ import annotations
|
|
|
|
import uuid
|
|
|
|
from sqlalchemy import select, update
|
|
from sqlalchemy.exc import IntegrityError
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from cpv3.db.base import utcnow
|
|
from cpv3.modules.project_workspaces.models import ProjectWorkspace
|
|
|
|
|
|
class WorkspaceRevisionConflictError(RuntimeError):
    """Signal that an optimistic-concurrency write of a workspace lost.

    Raised when a revision-guarded update did not modify exactly one row,
    i.e. the stored revision no longer matched the caller's expected one
    (or there was no row to update at all).
    """

    pass
|
|
|
|
|
|
class ProjectWorkspaceRepository:
    """Data-access helpers for ``ProjectWorkspace`` rows, looked up by project id.

    The repository operates on the ``AsyncSession`` it is given. Write methods
    call ``commit()`` on that session, which also persists any other pending
    work in it. The conflict/race paths call ``rollback()``, which discards
    any other pending work in it.
    """

    def __init__(self, session: AsyncSession) -> None:
        self._session = session

    async def get_by_project_id(self, project_id: uuid.UUID) -> ProjectWorkspace | None:
        """Return the workspace of ``project_id``, or ``None`` if there is none.

        Uses ``scalar_one_or_none``, so more than one matching row raises
        SQLAlchemy's ``MultipleResultsFound``.
        """
        result = await self._session.execute(
            select(ProjectWorkspace).where(ProjectWorkspace.project_id == project_id)
        )
        return result.scalar_one_or_none()

    async def create(self, *, project_id: uuid.UUID, state: dict) -> ProjectWorkspace:
        """Insert a workspace at revision 0 with ``state``, commit, and return it refreshed.

        An ``IntegrityError`` raised by the commit propagates to the caller
        without a rollback; ``get_or_create`` handles that case.
        """
        workspace = ProjectWorkspace(project_id=project_id, revision=0, state=state)
        self._session.add(workspace)
        await self._session.commit()
        # Reload server-side values (e.g. generated columns) onto the instance.
        await self._session.refresh(workspace)
        return workspace

    async def get_or_create(self, *, project_id: uuid.UUID, state: dict) -> ProjectWorkspace:
        """Return the existing workspace of ``project_id``, creating it with ``state`` if absent.

        If the insert fails with ``IntegrityError`` (presumably a concurrent
        creator won a unique constraint on ``project_id`` — confirm against the
        schema), the session is rolled back and the row is re-read. If the row
        still does not exist, the original ``IntegrityError`` is re-raised.
        """
        workspace = await self.get_by_project_id(project_id)
        if workspace is not None:
            return workspace

        try:
            return await self.create(project_id=project_id, state=state)
        except IntegrityError:
            # The failed commit leaves the session unusable until rolled back.
            await self._session.rollback()
            workspace = await self.get_by_project_id(project_id)
            if workspace is None:
                # Not a lost race (e.g. some other constraint failed): re-raise.
                raise
            return workspace

    async def update_state(
        self,
        *,
        project_id: uuid.UUID,
        expected_revision: int,
        state: dict,
    ) -> ProjectWorkspace:
        """Replace the workspace state if its revision still equals ``expected_revision``.

        Issues one conditional UPDATE (compare-and-swap on ``revision``). It sets
        ``state``, bumps ``revision`` to ``expected_revision + 1``, stamps
        ``updated_at``, and then commits.

        Returns:
            The workspace re-read from the database after the commit.

        Raises:
            WorkspaceRevisionConflictError: if the UPDATE did not affect exactly
                one row — the revision has moved on, or ``project_id`` has no
                workspace. The session is rolled back before raising.
            RuntimeError: if the row cannot be found again after the commit.
        """
        stmt = (
            update(ProjectWorkspace)
            .where(ProjectWorkspace.project_id == project_id)
            .where(ProjectWorkspace.revision == expected_revision)
            .values(
                state=state,
                revision=expected_revision + 1,
                updated_at=utcnow(),
            )
        )
        result = await self._session.execute(stmt)
        if result.rowcount != 1:
            await self._session.rollback()
            raise WorkspaceRevisionConflictError

        await self._session.commit()

        # Re-read with populate_existing. If this session's identity map already
        # holds this ProjectWorkspace with loaded (non-expired) attributes, e.g.
        # when the session uses expire_on_commit=False, a plain SELECT hands back
        # that cached object unchanged. It may then carry the pre-update
        # state/revision whenever the ORM could not synchronize it in Python.
        # populate_existing overwrites the cached attributes with the committed
        # row. This is safe here because commit() has just flushed every pending
        # change. Note the re-read returns the latest committed row, which may
        # already include a later writer's revision.
        refreshed = await self._session.execute(
            select(ProjectWorkspace)
            .where(ProjectWorkspace.project_id == project_id)
            .execution_options(populate_existing=True)
        )
        workspace = refreshed.scalar_one_or_none()
        if workspace is None:
            raise RuntimeError("Workspace disappeared after update")
        return workspace
|