from __future__ import annotations

import uuid

from sqlalchemy import Select, func, select, update
from sqlalchemy.ext.asyncio import AsyncSession

from cpv3.modules.notifications.models import Notification
from cpv3.modules.notifications.schemas import NotificationCreate


class NotificationRepository:
    """Async data-access layer for ``Notification`` rows.

    Every write method commits the session it was given before returning.
    Rows whose ``is_active`` flag is false are treated as hidden: they are
    excluded from listing and counting, and are not touched by the
    mark-as-read operations, so all methods agree on which rows exist.
    """

    def __init__(self, session: AsyncSession) -> None:
        """Bind the repository to *session*; every query runs through it."""
        self._session = session

    async def list_for_user(
        self,
        user_id: uuid.UUID,
        *,
        limit: int = 50,
        unread_only: bool = False,
    ) -> list[Notification]:
        """Return up to *limit* active notifications for *user_id*, newest first.

        Args:
            user_id: Owner whose notifications are listed.
            limit: Maximum number of rows returned.
            unread_only: When true, only rows with ``is_read`` false are returned.

        Returns:
            Notifications ordered by ``created_at`` descending. ``id`` is used
            as a tiebreaker so rows sharing a timestamp come back in a stable
            order across calls.
        """
        stmt: Select[tuple[Notification]] = (
            select(Notification)
            .where(Notification.user_id == user_id)
            .where(Notification.is_active.is_(True))
        )
        if unread_only:
            stmt = stmt.where(Notification.is_read.is_(False))
        # Filters first, then ordering and LIMIT, so the statement is built in
        # the same order the SQL is evaluated.
        stmt = stmt.order_by(
            Notification.created_at.desc(),
            Notification.id.desc(),
        ).limit(limit)
        result = await self._session.execute(stmt)
        return list(result.scalars().all())

    async def create(self, data: NotificationCreate) -> Notification:
        """Insert a notification built from *data*, commit, and return it.

        The instance is refreshed after the commit so server-side values
        (whatever the database fills in on insert) are loaded onto it.
        """
        notification = Notification(
            user_id=data.user_id,
            job_id=data.job_id,
            project_id=data.project_id,
            notification_type=data.notification_type,
            title=data.title,
            message=data.message,
            payload=data.payload,
        )
        self._session.add(notification)
        await self._session.commit()
        await self._session.refresh(notification)
        return notification

    async def mark_read(self, notification_id: uuid.UUID, user_id: uuid.UUID) -> bool:
        """Set ``is_read`` on one of *user_id*'s active notifications.

        Args:
            notification_id: Primary key of the notification.
            user_id: Required owner; another user's notification is never updated.

        Returns:
            True if the UPDATE matched a row (the notification exists, belongs
            to *user_id* and is active), False otherwise. A row that was
            already read still counts as matched.
        """
        result = await self._session.execute(
            update(Notification)
            .where(Notification.id == notification_id)
            .where(Notification.user_id == user_id)
            # Hidden rows are invisible to list/count, so treat them as absent here too.
            .where(Notification.is_active.is_(True))
            .values(is_read=True)
        )
        await self._session.commit()
        return result.rowcount > 0  # type: ignore[union-attr]

    async def mark_all_read(self, user_id: uuid.UUID) -> int:
        """Mark every unread, active notification of *user_id* as read.

        Returns:
            The number of rows the UPDATE matched. Because hidden rows are
            excluded, this equals what ``count_unread`` reported beforehand,
            barring concurrent writes.
        """
        result = await self._session.execute(
            update(Notification)
            .where(Notification.user_id == user_id)
            .where(Notification.is_active.is_(True))
            .where(Notification.is_read.is_(False))
            .values(is_read=True)
        )
        await self._session.commit()
        return result.rowcount  # type: ignore[return-value]

    async def count_unread(self, user_id: uuid.UUID) -> int:
        """Return how many active notifications of *user_id* are unread."""
        result = await self._session.execute(
            select(func.count())
            .select_from(Notification)
            .where(Notification.user_id == user_id)
            .where(Notification.is_active.is_(True))
            .where(Notification.is_read.is_(False))
        )
        return result.scalar_one()