new features
This commit is contained in:
+38
-34
@@ -17,44 +17,48 @@ _bearer = HTTPBearer(auto_error=True)
|
||||
|
||||
async def get_current_user(
|
||||
credentials: HTTPAuthorizationCredentials = Depends(_bearer),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
db: AsyncSession = Depends(get_db, use_cache=False),
|
||||
) -> User:
|
||||
token = credentials.credentials
|
||||
|
||||
try:
|
||||
payload = decode_token(token)
|
||||
except ExpiredSignatureError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail="Token expired"
|
||||
) from e
|
||||
except InvalidTokenError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token"
|
||||
) from e
|
||||
token = credentials.credentials
|
||||
|
||||
if payload.get("type") != "access":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token"
|
||||
)
|
||||
try:
|
||||
payload = decode_token(token)
|
||||
except ExpiredSignatureError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail="Token expired"
|
||||
) from e
|
||||
except InvalidTokenError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token"
|
||||
) from e
|
||||
|
||||
sub = payload.get("sub")
|
||||
if not sub:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token"
|
||||
)
|
||||
if payload.get("type") != "access":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token"
|
||||
)
|
||||
|
||||
try:
|
||||
user_id = uuid.UUID(str(sub))
|
||||
except ValueError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token"
|
||||
) from e
|
||||
sub = payload.get("sub")
|
||||
if not sub:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token"
|
||||
)
|
||||
|
||||
user_repo = UserRepository(db)
|
||||
user = await user_repo.get_by_id(user_id)
|
||||
if user is None or not user.is_active:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid credentials"
|
||||
)
|
||||
try:
|
||||
user_id = uuid.UUID(str(sub))
|
||||
except ValueError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token"
|
||||
) from e
|
||||
|
||||
return user
|
||||
user_repo = UserRepository(db)
|
||||
user = await user_repo.get_by_id(user_id)
|
||||
if user is None or not user.is_active:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid credentials"
|
||||
)
|
||||
|
||||
return user
|
||||
finally:
|
||||
# Free the auth session immediately so long-running handlers don't pin a pool slot.
|
||||
await db.close()
|
||||
|
||||
@@ -17,7 +17,11 @@ class Settings(BaseSettings):
|
||||
# App
|
||||
debug: bool = Field(default=True, alias="DEBUG")
|
||||
cors_allowed_origins: list[str] = Field(
|
||||
default_factory=lambda: ["http://localhost:3000", "http://localhost:8000"],
|
||||
default_factory=lambda: [
|
||||
"http://localhost:3000",
|
||||
"http://localhost:3001",
|
||||
"http://localhost:8000",
|
||||
],
|
||||
alias="CORS_ALLOWED_ORIGINS",
|
||||
)
|
||||
|
||||
@@ -37,6 +41,13 @@ class Settings(BaseSettings):
|
||||
)
|
||||
|
||||
database_url: str | None = Field(default=None, alias="DATABASE_URL")
|
||||
db_pool_size: int = Field(default=5, alias="DB_POOL_SIZE")
|
||||
db_max_overflow: int = Field(default=10, alias="DB_MAX_OVERFLOW")
|
||||
db_pool_timeout: int = Field(default=30, alias="DB_POOL_TIMEOUT")
|
||||
db_pool_recycle_seconds: int = Field(
|
||||
default=1800,
|
||||
alias="DB_POOL_RECYCLE_SECONDS",
|
||||
)
|
||||
|
||||
# Storage
|
||||
storage_backend: str = Field(default="S3", alias="STORAGE_BACKEND")
|
||||
|
||||
@@ -0,0 +1,13 @@
|
||||
"""Storage utility helpers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from cpv3.modules.users.models import User
|
||||
|
||||
|
||||
def get_user_folder(user: User) -> str:
|
||||
"""Return the per-user S3 folder prefix: ``<username>_<user_id>``."""
|
||||
return f"{user.username}_{user.id}"
|
||||
Reference in New Issue
Block a user