80 lines
2.7 KiB
Python
80 lines
2.7 KiB
Python
from __future__ import annotations
|
|
|
|
import uuid
|
|
|
|
from sqlalchemy import Select, func, select, update
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from cpv3.modules.notifications.models import Notification
|
|
from cpv3.modules.notifications.schemas import NotificationCreate
|
|
|
|
|
|
class NotificationRepository:
    """Data-access helper for ``Notification`` rows over an ``AsyncSession``.

    The session is supplied by the caller and shared by every method. The
    mutating methods (``create``, ``mark_read``, ``mark_all_read``) each
    commit that session before returning.
    """

    def __init__(self, session: AsyncSession) -> None:
        """Keep ``session`` for use by all repository methods."""
        self._session = session

    async def list_for_user(
        self,
        user_id: uuid.UUID,
        *,
        limit: int = 50,
        unread_only: bool = False,
    ) -> list[Notification]:
        """Fetch a user's active notifications, newest first.

        Args:
            user_id: Owner whose notifications are returned.
            limit: Maximum number of rows to return.
            unread_only: When true, rows whose ``is_read`` is set are excluded.

        Returns:
            At most ``limit`` notifications with ``is_active`` true, ordered
            by ``created_at`` descending.
        """
        # All filters are ANDed together, exactly like chained .where() calls.
        criteria = [
            Notification.user_id == user_id,
            Notification.is_active.is_(True),
        ]
        if unread_only:
            criteria.append(Notification.is_read.is_(False))

        query: Select[tuple[Notification]] = (
            select(Notification)
            .where(*criteria)
            .order_by(Notification.created_at.desc())
            .limit(limit)
        )
        rows = await self._session.execute(query)
        return list(rows.scalars().all())

    async def create(self, data: NotificationCreate) -> Notification:
        """Persist a new notification built from ``data`` and return it.

        The instance is added to the session, the session is committed, and
        the instance is then refreshed from the database before being returned.
        """
        record = Notification(
            user_id=data.user_id,
            job_id=data.job_id,
            project_id=data.project_id,
            notification_type=data.notification_type,
            title=data.title,
            message=data.message,
            payload=data.payload,
        )
        session = self._session
        session.add(record)
        await session.commit()
        await session.refresh(record)
        return record

    async def mark_read(
        self, notification_id: uuid.UUID, user_id: uuid.UUID
    ) -> bool:
        """Set ``is_read`` on one notification, scoped to its owner.

        Only the row whose ``id`` equals ``notification_id`` *and* whose
        ``user_id`` equals ``user_id`` is updated. The session is committed
        before returning.

        Returns:
            True when the driver-reported rowcount for the UPDATE is positive.
        """
        # NOTE(review): unlike list_for_user/count_unread there is no
        # is_active filter here, and rows already read are not excluded —
        # confirm that is intended.
        statement = (
            update(Notification)
            .where(
                Notification.id == notification_id,
                Notification.user_id == user_id,
            )
            .values(is_read=True)
        )
        outcome = await self._session.execute(statement)
        await self._session.commit()
        affected = outcome.rowcount  # type: ignore[union-attr]
        # Compare with > 0 (not bool()) so a -1 "unknown" rowcount reads False.
        return affected > 0

    async def mark_all_read(self, user_id: uuid.UUID) -> int:
        """Set ``is_read`` on every unread notification owned by ``user_id``.

        The session is committed before returning.

        Returns:
            The rowcount reported for the UPDATE.
        """
        # NOTE(review): no is_active filter, so the returned count may differ
        # from count_unread(), which only counts active rows — confirm intent.
        statement = (
            update(Notification)
            .where(
                Notification.user_id == user_id,
                Notification.is_read.is_(False),
            )
            .values(is_read=True)
        )
        outcome = await self._session.execute(statement)
        await self._session.commit()
        return outcome.rowcount  # type: ignore[return-value]

    async def count_unread(self, user_id: uuid.UUID) -> int:
        """Return how many active, unread notifications ``user_id`` has."""
        tally = (
            select(func.count())
            .select_from(Notification)
            .where(
                Notification.user_id == user_id,
                Notification.is_active.is_(True),
                Notification.is_read.is_(False),
            )
        )
        return (await self._session.execute(tally)).scalar_one()
|