Source code for derp.auth.native_client

"""Core authentication service."""

from __future__ import annotations

import logging
import uuid
from datetime import UTC, datetime, timedelta
from typing import Any

from derp.auth.base import BaseAuthClient
from derp.auth.email import EmailClient
from derp.auth.exceptions import (
    ConfirmationURLMissingError,
    EmailAlreadyExistsError,
    InvalidCredentialsError,
    InvalidTokenError,
    LastOwnerError,
    MemberAlreadyExistsError,
    OrgMemberNotFoundError,
    OrgNotFoundError,
    OrgSlugConflictError,
    PasswordValidationError,
    SignupDisabledError,
    UserNotFoundError,
)
from derp.auth.jwt import TokenPair, create_token_pair, decode_token
from derp.auth.models import (
    AuthOrganization,
    AuthOrgMember,
    AuthProvider,
    AuthRequest,
    AuthResult,
    AuthSession,
    AuthUser,
    OrgInfo,
    OrgMemberInfo,
    SessionInfo,
    UserInfo,
)
from derp.auth.password import (
    Argon2Hasher,
    PasswordHasher,
    generate_secure_token,
    validate_password,
)
from derp.auth.providers.base import BaseOAuthProvider
from derp.auth.providers.github import GitHubProvider
from derp.auth.providers.google import GoogleProvider
from derp.config import NativeAuthConfig
from derp.kv.base import KVClient
from derp.orm import DatabaseEngine

logger = logging.getLogger(__name__)


class NativeAuthClient(BaseAuthClient):
    """Native authentication client (email/password, magic link, OAuth).

    The database engine, KV store and email client are not passed to the
    constructor; attach them with ``set_db``, ``set_kv`` and ``set_email``.
    """
def __init__(self, config: NativeAuthConfig):
    """Create the client from *config*.

    Collaborators (database, KV store, email) start out unset, and an
    OAuth provider is registered for each provider section present in
    the configuration.
    """
    self._config: NativeAuthConfig = config
    self._hasher: PasswordHasher = Argon2Hasher()

    # Injected later through the ``set_*`` methods.
    self._database_client: DatabaseEngine | None = None
    self._kv_client: KVClient | None = None
    self._email_client: EmailClient | None = None

    # Only providers that actually have configuration are registered.
    providers: dict[AuthProvider, BaseOAuthProvider] = {}
    google_settings = config.google_oauth
    if google_settings is not None:
        providers[AuthProvider.GOOGLE] = GoogleProvider(google_settings)
    github_settings = config.github_oauth
    if github_settings is not None:
        providers[AuthProvider.GITHUB] = GitHubProvider(github_settings)
    self._oauth_providers: dict[AuthProvider, BaseOAuthProvider] = providers
def set_db(self, db: DatabaseEngine | None) -> None:
    """Attach the database engine (pass ``None`` to detach it)."""
    self._database_client = db
def _db(self) -> DatabaseEngine:
    """Return the attached database engine.

    Raises:
        ValueError: ``set_db()`` has not been called with an engine.
    """
    engine = self._database_client
    if engine is not None:
        return engine
    raise ValueError("Database client not set. Must call `set_db()` first.")
def set_kv(self, kv: KVClient | None) -> None:
    """Attach the KV store used for caching and token storage.

    Passing ``None`` detaches the store.
    """
    self._kv_client = kv
def _kv(self) -> KVClient:
    """Return the attached KV client, which token operations require.

    Raises:
        ValueError: ``set_kv()`` has not been called with a client.
    """
    store = self._kv_client
    if store is not None:
        return store
    raise ValueError(
        "KV client not set. Token operations (recovery, confirmation, "
        "magic link) require a KV store. Call `set_kv()` first."
    )
def set_email(self, email_client: EmailClient | None) -> None:
    """Attach the email client (pass ``None`` to detach it)."""
    self._email_client = email_client
def _email(self) -> EmailClient:
    """Return the attached email client.

    Raises:
        ValueError: ``set_email()`` has not been called with a client.
    """
    if self._email_client is None:
        raise ValueError("Email client not set. Must call `set_email()` first.")
    return self._email_client

@staticmethod
def _extract_client_info(
    request: AuthRequest | None,
    user_agent: str | None,
    ip_address: str | None,
) -> tuple[str | None, str | None]:
    """Fill in *user_agent* and *ip_address* from *request* when not given.

    Explicitly supplied values always win. Otherwise the user agent is
    read from the ``User-Agent`` header, and the IP address from the first
    entry of ``X-Forwarded-For``, falling back to ``request.client.host``.

    Returns:
        ``(user_agent, ip_address)``; either element may be ``None``.
    """
    if request is None:
        return user_agent, ip_address

    if user_agent is None:
        user_agent = request.headers.get("User-Agent")

    if ip_address is None:
        forwarded = request.headers.get("X-Forwarded-For")
        if forwarded:
            # A blank first entry (e.g. " , 5.6.7.8") must not be recorded
            # as an empty-string address; treat it as absent instead.
            ip_address = forwarded.split(",")[0].strip() or None
        if ip_address is None:
            client = getattr(request, "client", None)
            if client is not None:
                ip_address = getattr(client, "host", None)

    return user_agent, ip_address

async def _invalidate_user_cache(
    self, user_id: str | uuid.UUID, email: str | None = None
) -> None:
    """Delete the cached entries for a user from the KV store.

    Removes the by-id entry and, when *email* is given, the by-email
    entry (keyed on the lower-cased address). Does nothing when no KV
    client is attached.
    """
    kv = self._kv_client
    if kv is None:
        return
    prefix = self._config.cache_prefix
    await kv.delete(f"{prefix}:user:{user_id}".encode())
    if email is not None:
        await kv.delete(f"{prefix}:user:email:{email.lower()}".encode())

# =========================================================================
# User Management
# =========================================================================

def _to_user_info(self, user: AuthUser) -> UserInfo:
    """Convert an internal AuthUser ORM model to a public UserInfo."""
    # ``provider`` may be an enum member or a plain value; expose the
    # underlying value in either case.
    provider = user.provider
    if hasattr(provider, "value"):
        provider = provider.value
    return UserInfo(
        id=str(user.id),
        email=user.email,
        first_name=user.first_name,
        last_name=user.last_name,
        username=user.username,
        image_url=user.image_url,
        role=user.role,
        is_active=user.is_active,
        is_superuser=user.is_superuser,
        created_at=user.created_at,
        updated_at=user.updated_at,
        last_sign_in_at=user.last_sign_in_at,
        email_confirmed_at=user.email_confirmed_at,
        metadata={
            "provider": provider,
            "provider_id": user.provider_id,
        },
    )
async def _fetch_user(self, user_id: str | uuid.UUID) -> AuthUser | None:
    """Fetch a user row by id, going through the KV cache when enabled.

    With caching on, the lookup is delegated to the KV client's
    ``guarded_get``; a missing user is stored as an empty byte string so
    that "not found" is cached as well.
    """

    async def _load() -> AuthUser | None:
        return await (
            self._db()
            .select(AuthUser)
            .where(AuthUser.id == str(user_id))
            .first_or_none()
        )

    kv = self._kv_client
    if not self._config.use_kv_cache or kv is None:
        return await _load()

    async def _compute() -> bytes:
        row = await _load()
        # Empty bytes is the sentinel for "no such user".
        return b"" if row is None else row.to_json().encode()

    payload = await kv.guarded_get(
        f"{self._config.cache_prefix}:user:{user_id}".encode(),
        compute=_compute,
        ttl=self._config.cache_user_ttl_seconds,
    )
    if payload == b"":
        return None
    return AuthUser.from_json(payload)
async def get_user(self, user_id: str | uuid.UUID) -> UserInfo:
    """Return the public profile of the user with the given id.

    Raises:
        UserNotFoundError: No user with that id.
    """
    record = await self._fetch_user(user_id)
    if record is not None:
        return self._to_user_info(record)
    raise UserNotFoundError(f"User {user_id!r} not found")
async def list_users(
    self, *, limit: int | None = None, offset: int | None = None
) -> list[UserInfo]:
    """Return users ordered by creation date, newest first.

    Args:
        limit: Maximum number of users to return; unlimited when ``None``.
        offset: Number of users to skip; none when ``None``.
    """
    query = self._db().select(AuthUser).order_by(AuthUser.created_at, asc=False)
    if limit is not None:
        query = query.limit(limit)
    if offset is not None:
        query = query.offset(offset)
    rows = await query.execute()
    return [self._to_user_info(row) for row in rows]
async def _get_user_by_email(self, email: str) -> AuthUser | None:
    """Look up a user by email address (internal use only).

    The address is lower-cased before lookup. When KV caching is enabled,
    hits are served from the cache and rows found in the database are
    written back to it.

    Unlike ``get_user``, negative results (user not found) are **not**
    cached because email lookups are used in write paths (sign-up, OAuth)
    where a subsequent insert would leave a stale "not found" entry.
    """
    normalized = email.lower()
    kv = self._kv_client
    caching = self._config.use_kv_cache and kv is not None

    cache_key = b""
    if caching:
        cache_key = f"{self._config.cache_prefix}:user:email:{normalized}".encode()
        hit = await kv.get(cache_key)
        if hit is not None:
            return AuthUser.from_json(hit)

    found = await (
        self._db()
        .select(AuthUser)
        .where(AuthUser.email == normalized)
        .first_or_none()
    )

    # Only positive results are written back.
    if caching and found is not None:
        await kv.set(
            cache_key,
            found.to_json().encode(),
            ttl=self._config.cache_user_ttl_seconds,
        )
    return found
async def update_user(
    self,
    *,
    user_id: str | uuid.UUID,
    email: str | None = None,
    **kwargs: Any,
) -> UserInfo:
    """Update user data.

    Args:
        user_id: Id of the user to update.
        email: New email address; stored lower-cased.
        **kwargs: Additional ``AuthUser`` column values to set.

    Raises:
        UserNotFoundError: No user with that id (including the case where
            the row disappeared between lookup and update).
        ValueError: A keyword argument does not name an ``AuthUser`` column.
    """
    current = await self._fetch_user(user_id=user_id)
    if not current:
        raise UserNotFoundError(f"User {user_id!r} not found")

    new_email = email.lower() if email is not None else None

    changes: dict[str, Any] = {"updated_at": datetime.now(UTC)}
    if new_email is not None:
        changes["email"] = new_email
    if kwargs:
        columns = AuthUser.get_columns()
        for key, value in kwargs.items():
            if key not in columns:
                raise ValueError(f"Invalid user field: {key}.")
            changes[key] = value

    rows = await (
        self._db()
        .update(AuthUser)
        .set(**changes)
        .where(AuthUser.id == str(user_id))
        .returning(AuthUser)
        .execute()
    )

    # Drop cached copies whether or not the UPDATE matched a row: the
    # cached user that got us here may be stale.
    await self._invalidate_user_cache(user_id, current.email)
    if not rows:
        raise UserNotFoundError(f"User {user_id!r} not found")
    [result] = rows

    if new_email is not None and new_email != current.email:
        await self._invalidate_user_cache(user_id, new_email)
    return self._to_user_info(result)
async def delete_user(self, user_id: str | uuid.UUID) -> bool:
    """Delete a user together with all of their sessions.

    Returns:
        ``True`` when a user was deleted, ``False`` when no user has
        that id.
    """
    target = str(user_id)
    row = await (
        self._db()
        .select(AuthUser.id, AuthUser.email)
        .from_(AuthUser)
        .where(AuthUser.id == target)
        .first_or_none()
    )
    if not row:
        return False

    # The email is needed afterwards to drop the by-email cache entry.
    _, email = row
    await self.sign_out_all(user_id)
    await self._db().delete(AuthUser).where(AuthUser.id == target).execute()
    await self._invalidate_user_cache(user_id, email)
    return True
async def count_users(self) -> int:
    """Count all user rows."""
    query = self._db().select(AuthUser)
    return await query.count()
# ========================================================================= # Email/Password Authentication # =========================================================================
async def sign_up(
    self,
    *,
    email: str,
    password: str,
    request: AuthRequest | None = None,
    confirmation_url: str | None = None,
    confirmation_subject: str = "Confirm your email address",
    user_agent: str | None = None,
    ip_address: str | None = None,
    **kwargs: Any,
) -> AuthResult:
    """Register a new user with email and password.

    The email is stored lower-cased. When ``enable_confirmation`` is on,
    a confirmation token is stored in the KV store and emailed to the
    user. A session is created and its tokens returned in either case.

    Args:
        email: Address to register.
        password: Plain-text password; validated against the configured
            password policy and stored hashed.
        request: Optional request used to fill in *user_agent* and
            *ip_address* when they are not given.
        confirmation_url: Base URL of the confirmation link; the token is
            appended as ``?token=...``.
        confirmation_subject: Subject line of the confirmation email.
        user_agent: Client user agent recorded on the session.
        ip_address: Client IP address recorded on the session.
        **kwargs: Additional ``AuthUser`` column values for the new row.

    Raises:
        SignupDisabledError: Signup is disabled in config.
        ConfirmationURLMissingError: confirmation_url omitted while
            ``enable_confirmation`` is on.
        PasswordValidationError: Password did not meet requirements.
        EmailAlreadyExistsError: An account with this email exists.
        ValueError: A keyword argument does not name an ``AuthUser``
            column, or confirmation is enabled but the KV or email
            client has not been set.
    """
    user_agent, ip_address = self._extract_client_info(
        request, user_agent, ip_address
    )

    if not self._config.enable_signup:
        raise SignupDisabledError()

    confirmation_required = self._config.enable_confirmation
    if confirmation_url is None and confirmation_required:
        raise ConfirmationURLMissingError(
            "`confirmation_url` is required when confirmation is enabled."
        )

    # Validate password
    validation = validate_password(self._config.password, password)
    if not validation.valid:
        raise PasswordValidationError("; ".join(validation.errors))

    # Check if user exists
    normalized_email = email.lower()
    exists = await (
        self._db()
        .select(AuthUser.id)
        .from_(AuthUser)
        .where(AuthUser.email == normalized_email)
        .first_or_none()
    )
    if exists:
        raise EmailAlreadyExistsError(normalized_email)

    # Reject unknown fields before paying for the password hash.
    vals: dict[str, Any] = {}
    if kwargs:
        columns = AuthUser.get_columns()
        for key, value in kwargs.items():
            if key not in columns:
                raise ValueError(f"Invalid user field: {key}.")
            vals[key] = value

    # Resolve the confirmation collaborators *before* inserting the user.
    # Otherwise a missing KV/email client would raise only after the row
    # exists, leaving an unconfirmed account that blocks a retry.
    kv = None
    mailer = None
    if confirmation_required:
        kv = self._kv()
        mailer = self._email()

    # Create user
    hashed_password = await self._hasher.async_hash(password)
    now = datetime.now(UTC)
    email_confirmed_at = None if confirmation_required else now

    user = await (
        self._db()
        .insert(AuthUser)
        .values(
            email=normalized_email,
            encrypted_password=hashed_password,
            provider=AuthProvider.EMAIL,
            email_confirmed_at=email_confirmed_at,
            created_at=now,
            updated_at=now,
            last_sign_in_at=now,
            **vals,
        )
        .returning(AuthUser)
        .execute()
    )

    # Store confirmation token in KV and send email if needed
    if confirmation_required:
        confirmation_token = generate_secure_token()
        ttl = self._config.confirmation_token_expire_hours * 3600
        await kv.set(
            f"{self._config.cache_prefix}:confirmation:{confirmation_token}".encode(),
            str(user.id).encode(),
            ttl=ttl,
        )
        await mailer.send_email(
            subject=confirmation_subject,
            to_email=normalized_email,
            template="confirmation.html",
            confirmation_url=f"{confirmation_url}?token={confirmation_token}",
        )

    # Create session and tokens
    token_pair = await self._create_session(
        user.id,
        role=user.role,
        user_agent=user_agent,
        ip_address=ip_address,
    )
    return AuthResult(user=self._to_user_info(user), tokens=token_pair)
async def sign_in_with_password(
    self,
    email: str,
    password: str,
    *,
    request: AuthRequest | None = None,
    first_name: str | None = None,
    last_name: str | None = None,
    user_agent: str | None = None,
    ip_address: str | None = None,
) -> AuthResult:
    """Sign in with email and password.

    On success the user's ``last_sign_in_at`` is refreshed (and the
    password re-hashed when the hasher asks for it) and a new session is
    created. ``first_name`` and ``last_name`` are accepted but not used
    by this implementation.

    Raises:
        InvalidCredentialsError: Email/password did not match, or the
            account is disabled / unconfirmed. The reason is logged at
            ``WARNING`` for ops; the caller-visible error is opaque to
            avoid email-enumeration leaks.
    """
    user_agent, ip_address = self._extract_client_info(
        request, user_agent, ip_address
    )

    account = await self._get_user_by_email(email=email.lower())

    # Every failure below raises the same opaque error; only the log
    # line says which check failed.
    if not account:
        logger.warning("Sign-in failed: user not found for %s", email)
        raise InvalidCredentialsError()

    stored_hash = account.encrypted_password
    if not stored_hash:
        logger.warning("Sign-in failed: no password set for %s", email)
        raise InvalidCredentialsError()

    password_ok = await self._hasher.async_verify(password, stored_hash)
    if not password_ok:
        logger.warning("Sign-in failed: invalid password for %s", email)
        raise InvalidCredentialsError()

    if not account.is_active:
        logger.warning("Sign-in failed: account disabled for %s", email)
        raise InvalidCredentialsError()

    if self._config.enable_confirmation and not account.email_confirmed_at:
        logger.warning("Sign-in failed: email not confirmed for %s", email)
        raise InvalidCredentialsError()

    # Refresh the sign-in timestamp and, if needed, the password hash in
    # one write.
    now = datetime.now(UTC)
    changes: dict[str, Any] = {"last_sign_in_at": now, "updated_at": now}
    if self._hasher.needs_rehash(stored_hash):
        changes["encrypted_password"] = await self._hasher.async_hash(password)

    [account] = await (
        self._db()
        .update(AuthUser)
        .set(**changes)
        .where(AuthUser.id == account.id)
        .returning(AuthUser)
        .execute()
    )
    await self._invalidate_user_cache(account.id, account.email)

    tokens = await self._create_session(
        account.id,
        role=account.role,
        user_agent=user_agent,
        ip_address=ip_address,
    )
    return AuthResult(user=self._to_user_info(account), tokens=tokens)
# ========================================================================= # Magic Link Authentication # ========================================================================= # ========================================================================= # OAuth Authentication # =========================================================================
def get_oauth_provider(self, provider: str | AuthProvider) -> BaseOAuthProvider:
    """Return the configured OAuth provider for *provider*.

    A string is first converted to an ``AuthProvider`` member.

    Raises:
        ValueError: No provider is configured under that key.
    """
    key = AuthProvider(provider) if isinstance(provider, str) else provider
    configured = self._oauth_providers.get(key)
    if configured is not None:
        return configured
    raise ValueError(f"OAuth provider not configured: {key}")
def get_oauth_authorization_url(
    self,
    provider: str | AuthProvider,
    state: str,
    scopes: list[str] | None = None,
    redirect_uri: str | None = None,
) -> str:
    """Get the OAuth authorization URL for a provider.

    Args:
        provider: OAuth provider, as an ``AuthProvider`` member or its
            string value.
        state: CSRF protection state token.
        scopes: Optional scopes to request.
        redirect_uri: Optional redirect URI override.

    Returns:
        The authorization URL produced by the provider.

    Raises:
        ValueError: The provider is not configured.
    """
    oauth_provider = self.get_oauth_provider(provider)
    return oauth_provider.get_authorization_url(state, scopes, redirect_uri)
async def sign_in_with_oauth(
    self,
    provider: str | AuthProvider,
    code: str,
    *,
    redirect_uri: str | None = None,
    user_agent: str | None = None,
    ip_address: str | None = None,
) -> AuthResult:
    """Complete OAuth sign in with authorization code.

    Creates the user record if this email has never signed in before.
    An existing account is matched by email address; if it was created
    with email/password, the OAuth provider and provider id are recorded
    on it.

    Args:
        provider: OAuth provider, as an ``AuthProvider`` member or its
            string value.
        code: Authorization code returned by the provider.
        redirect_uri: Optional redirect URI override.
        user_agent: Client user agent recorded on the session.
        ip_address: Client IP address recorded on the session.

    Raises:
        InvalidCredentialsError: Provider rejected the code, or the
            matching account is disabled.
        ValueError: The provider is unknown or not configured.
    """
    oauth_provider = self.get_oauth_provider(provider)

    # Persist the enum member rather than whatever the caller passed, so
    # stored values match what the rest of the client writes and compares
    # (e.g. ``AuthProvider.EMAIL``). The conversion cannot fail here:
    # ``get_oauth_provider`` has already accepted the same value.
    provider_key = AuthProvider(provider) if isinstance(provider, str) else provider

    # Get user info from provider
    user_info = await oauth_provider.authenticate(code, redirect_uri)
    if user_info is None:
        raise InvalidCredentialsError("OAuth code rejected by provider")

    # Find or create user
    user = await self._get_user_by_email(email=user_info.email)
    now = datetime.now(UTC)

    if user:
        if not user.is_active:
            logger.warning(
                "OAuth sign-in failed: account disabled for %s",
                user.email,
            )
            raise InvalidCredentialsError()

        updates: dict[str, Any] = {
            "last_sign_in_at": now,
            "updated_at": now,
        }

        # NOTE(review): the account is matched purely by email, even when
        # the provider reports ``email_verified`` as false. Confirm that
        # linking on an unverified address is intended.
        # Update provider info if first time with this provider
        if user.provider == AuthProvider.EMAIL:
            updates["provider"] = provider_key
            updates["provider_id"] = user_info.id

        # Confirm email if provider verified it
        if user_info.email_verified and not user.email_confirmed_at:
            updates["email_confirmed_at"] = now

        [user] = await (
            self._db()
            .update(AuthUser)
            .set(**updates)
            .where(AuthUser.id == user.id)
            .returning(AuthUser)
            .execute()
        )
        await self._invalidate_user_cache(user.id, user.email)
    else:
        user = await (
            self._db()
            .insert(AuthUser)
            .values(
                email=user_info.email.lower(),
                provider=provider_key,
                provider_id=user_info.id,
                email_confirmed_at=now if user_info.email_verified else None,
                created_at=now,
                updated_at=now,
                last_sign_in_at=now,
            )
            .returning(AuthUser)
            .execute()
        )

    token_pair = await self._create_session(
        user.id,
        role=user.role,
        user_agent=user_agent,
        ip_address=ip_address,
    )
    return AuthResult(user=self._to_user_info(user), tokens=token_pair)
# =========================================================================
# Session Management
# =========================================================================

async def _create_session(
    self,
    user_id: uuid.UUID,
    *,
    role: str = "default",
    user_agent: str | None = None,
    ip_address: str | None = None,
) -> TokenPair:
    """Insert a session row and return the token pair issued for it.

    The row stores a freshly generated refresh token and expires
    ``session_expire_days`` after creation. The role is also embedded in
    the token pair as an extra claim.
    """
    created_at = datetime.now(UTC)
    expires_at = created_at + timedelta(days=self._config.session_expire_days)
    refresh_token = generate_secure_token()

    insert = self._db().insert(AuthSession).values(
        user_id=user_id,
        token=refresh_token,
        role=role,
        user_agent=user_agent,
        ip_address=ip_address,
        not_after=expires_at,
        created_at=created_at,
    )
    session_id = await insert.returning(AuthSession.session_id).execute()

    return create_token_pair(
        self._config.jwt,
        user_id,
        session_id,
        refresh_token,
        extra_claims={"role": role},
    )
async def refresh_token(self, refresh_token: str) -> TokenPair:
    """Exchange a refresh token for a new token pair, rotating the token.

    The presented token is revoked and fetched in a single
    UPDATE...RETURNING, then a replacement row is inserted for the same
    session, so the happy path costs two DB calls.

    Raises:
        InvalidTokenError: Token is unknown, revoked, reused, or expired.
            When an already-revoked token is presented again, every token
            of that session is revoked before raising.
    """

    def _session_cache_key(session_id: Any) -> bytes:
        return f"{self._config.cache_prefix}:session:{session_id}".encode()

    # Revoke-and-return in one statement; an empty result means the token
    # is unknown or was already revoked.
    rotated = await (
        self._db()
        .update(AuthSession)
        .set(revoked=True)
        .eq(AuthSession.token, refresh_token)
        .not_(AuthSession.revoked)
        .returning(AuthSession)
        .execute()
    )

    if not rotated:
        # Distinguish "never existed" from "presented twice".
        seen = await (
            self._db()
            .select(AuthSession)
            .eq(AuthSession.token, refresh_token)
            .first_or_none()
        )
        reused = seen is not None and seen.revoked
        if not reused:
            logger.warning("Refresh token invalid or revoked")
            raise InvalidTokenError("Refresh token is invalid or revoked")

        # Reuse: revoke every token that belongs to the same session.
        await (
            self._db()
            .update(AuthSession)
            .set(revoked=True)
            .eq(AuthSession.session_id, seen.session_id)
            .execute()
        )
        if self._kv_client is not None:
            await self._kv_client.delete(_session_cache_key(seen.session_id))
        logger.warning("Refresh token reuse detected, all sessions revoked")
        raise InvalidTokenError("Refresh token reuse detected")

    [previous] = rotated

    if previous.not_after < datetime.now(UTC):
        logger.warning("Refresh token failed: session expired")
        raise InvalidTokenError("Refresh token session has expired")

    # The replacement row inherits everything but the token itself.
    replacement = generate_secure_token()
    await (
        self._db()
        .insert(AuthSession)
        .values(
            user_id=previous.user_id,
            session_id=previous.session_id,
            token=replacement,
            role=previous.role,
            user_agent=previous.user_agent,
            ip_address=previous.ip_address,
            not_after=previous.not_after,
            created_at=datetime.now(UTC),
        )
        .execute()
    )

    # Drop the cached session entry for this session id.
    if self._kv_client is not None:
        await self._kv_client.delete(_session_cache_key(previous.session_id))

    return create_token_pair(
        self._config.jwt,
        previous.user_id,
        previous.session_id,
        replacement,
        extra_claims={"role": previous.role},
    )
async def authenticate(self, request: AuthRequest) -> SessionInfo | None:
    """Authenticate a request from its Bearer JWT.

    Reads the ``Authorization`` header, decodes the token with the
    configured JWT settings and builds a ``SessionInfo`` from the claims.
    No database or KV lookup is made here.

    Returns:
        The session described by the token, or ``None`` when the header
        is missing, is not a Bearer header, or the token does not decode.
    """
    scheme = "Bearer "
    header = request.headers.get("Authorization", "")
    if not header.startswith(scheme):
        return None

    claims = decode_token(self._config.jwt, header[len(scheme):])
    if claims is None:
        return None

    extra = claims.extra or {}
    return SessionInfo(
        user_id=claims.sub,
        session_id=claims.session_id,
        role=extra.get("role", "default"),
        expires_at=claims.exp,
        metadata=extra,
        org_id=extra.get("org_id"),
        org_role=extra.get("org_role"),
    )
async def list_sessions(
    self,
    *,
    user_id: str | uuid.UUID | None = None,
    limit: int | None = None,
    offset: int | None = None,
) -> list[SessionInfo]:
    """List non-revoked sessions ordered by creation date, newest first.

    Args:
        user_id: Restrict the listing to this user when given.
        limit: Maximum number of sessions to return.
        offset: Number of sessions to skip.

    Returns:
        One ``SessionInfo`` per matching session row. ``role`` is the
        role stored on the row (``"default"`` when the row has none);
        ``metadata`` is empty and ``org_role`` is ``None`` because the
        row does not supply them here.
    """
    q = (
        self._db()
        .select(AuthSession)
        .where(~AuthSession.revoked)
        .order_by(AuthSession.created_at, asc=False)
    )
    if user_id is not None:
        q = q.where(AuthSession.user_id == str(user_id))
    if limit is not None:
        q = q.limit(limit)
    if offset is not None:
        q = q.offset(offset)

    sessions = await q.execute()
    return [
        SessionInfo(
            user_id=str(s.user_id),
            session_id=str(s.session_id),
            # Sessions are created with an explicit role; report it
            # instead of a hard-coded placeholder.
            role=s.role or "default",
            expires_at=s.not_after,
            metadata={},
            org_id=s.org_id,
            org_role=None,
        )
        for s in sessions
    ]
async def sign_out(self, session_id: str | uuid.UUID) -> None:
    """End a session by deleting every token row that belongs to it.

    The session's KV cache entry is removed as well when a KV client is
    attached.
    """
    delete_tokens = (
        self._db()
        .delete(AuthSession)
        .where(AuthSession.session_id == str(session_id))
    )
    await delete_tokens.execute()

    kv = self._kv_client
    if kv is None:
        return
    await kv.delete(f"{self._config.cache_prefix}:session:{session_id}".encode())
async def sign_out_all(self, user_id: str | uuid.UUID) -> None:
    """End every session of a user by deleting all of their token rows.

    The KV cache entries of the removed sessions are deleted in one
    batch when a KV client is attached.
    """
    removed_ids = await (
        self._db()
        .delete(AuthSession)
        .where(AuthSession.user_id == str(user_id))
        .returning(AuthSession.session_id)
        .execute()
    )

    # Nothing to evict when no rows were removed or no KV is attached.
    if not removed_ids or self._kv_client is None:
        return

    stale_keys = []
    for sid in removed_ids:
        stale_keys.append(f"{self._config.cache_prefix}:session:{sid}".encode())
    await self._kv_client.delete_many(stale_keys)
# ========================================================================= # Password Recovery # =========================================================================
async def request_password_recovery(
    self,
    *,
    email: str,
    recovery_url: str,
    recovery_subject: str = "Reset your password",
    **kwargs: Any,
) -> None:
    """Email a password recovery link to the account owner.

    Returns silently when the address is unknown or the account is
    disabled, so the caller cannot tell whether the account exists.

    Args:
        email: Address of the account; matched lower-cased.
        recovery_url: Base URL of the recovery link; the token is
            appended as ``?token=...``.
        recovery_subject: Subject line of the email.
        **kwargs: Passed through to the email client's ``send_email``.
    """
    address = email.lower()
    row = await (
        self._db()
        .select(AuthUser.id, AuthUser.is_active)
        .from_(AuthUser)
        .where(AuthUser.email == address)
        .first_or_none()
    )
    if not row:
        return  # Unknown address: stay silent.

    uid, is_active = row
    if not is_active:
        return  # Disabled account: stay silent.

    # The token maps to the user id in KV until it expires.
    token = generate_secure_token()
    lifetime_seconds = self._config.recovery_token_expire_minutes * 60
    await self._kv().set(
        f"{self._config.cache_prefix}:recovery:{token}".encode(),
        str(uid).encode(),
        ttl=lifetime_seconds,
    )

    await self._email().send_email(
        subject=recovery_subject,
        to_email=address,
        template="recovery.html",
        recovery_url=f"{recovery_url}?token={token}",
        **kwargs,
    )
async def reset_password(self, token: str, new_password: str) -> UserInfo:
    """Reset password using recovery token.

    The token is single use: it is deleted from the KV store as soon as
    it has been read. After the password is changed, all of the user's
    sessions are signed out.

    Raises:
        PasswordValidationError: New password did not meet requirements.
        InvalidTokenError: Recovery token is unknown, expired, used, or
            points at a missing user.
        ValueError: No KV client has been set.
    """
    # Validate password
    validation = validate_password(self._config.password, new_password)
    if not validation.valid:
        raise PasswordValidationError("; ".join(validation.errors))

    # Look up recovery token in KV
    kv_key = f"{self._config.cache_prefix}:recovery:{token}".encode()
    user_id_bytes = await self._kv().get(kv_key)
    if user_id_bytes is None:
        raise InvalidTokenError("Recovery token is invalid or expired")

    # Delete token (single use)
    await self._kv().delete(kv_key)

    user = await self._fetch_user(user_id=user_id_bytes.decode())
    if user is None:
        raise InvalidTokenError("Recovery token is invalid or expired")

    # Update password
    hashed_password = await self._hasher.async_hash(new_password)
    now = datetime.now(UTC)
    rows = await (
        self._db()
        .update(AuthUser)
        .set(encrypted_password=hashed_password, updated_at=now)
        .where(AuthUser.id == user.id)
        .returning(AuthUser)
        .execute()
    )
    await self._invalidate_user_cache(user.id, user.email)

    # The user may have come from a stale cache entry or been deleted in
    # the meantime; report that as the documented token error instead of
    # failing on the unpacking below.
    if not rows:
        raise InvalidTokenError("Recovery token is invalid or expired")
    [result] = rows

    # Sign out all sessions (security measure)
    await self.sign_out_all(user.id)

    return self._to_user_info(result)
# ========================================================================= # Email Confirmation # =========================================================================
async def confirm_email(self, token: str) -> UserInfo:
    """Confirm email address with token.

    The token is single use: it is deleted from the KV store as soon as
    it has been read.

    Raises:
        InvalidTokenError: Confirmation token is unknown, expired, used,
            or points at a missing user.
        ValueError: No KV client has been set.
    """
    kv_key = f"{self._config.cache_prefix}:confirmation:{token}".encode()
    user_id_bytes = await self._kv().get(kv_key)
    if user_id_bytes is None:
        raise InvalidTokenError("Confirmation token is invalid or expired")

    # Delete token (single use)
    await self._kv().delete(kv_key)

    user = await self._fetch_user(user_id=user_id_bytes.decode())
    if user is None:
        raise InvalidTokenError("Confirmation token is invalid or expired")

    # Confirm email
    now = datetime.now(UTC)
    rows = await (
        self._db()
        .update(AuthUser)
        .set(email_confirmed_at=now, updated_at=now)
        .where(AuthUser.id == user.id)
        .returning(AuthUser)
        .execute()
    )
    await self._invalidate_user_cache(user.id, user.email)

    # The user may have come from a stale cache entry or been deleted in
    # the meantime; report that as the documented token error instead of
    # failing on the unpacking below.
    if not rows:
        raise InvalidTokenError("Confirmation token is invalid or expired")
    [result] = rows

    return self._to_user_info(result)
async def resend_confirmation_email(
    self,
    *,
    email: str,
    confirmation_url: str,
    confirmation_subject: str = "Confirm your email address",
    **kwargs: Any,
) -> None:
    """Issue a new confirmation token and email it to the user.

    Returns silently when the address is unknown or already confirmed,
    so the caller cannot tell whether the account exists.

    Args:
        email: Address of the account; matched lower-cased.
        confirmation_url: Base URL of the confirmation link; the token is
            appended as ``?token=...``.
        confirmation_subject: Subject line of the email.
        **kwargs: Passed through to the email client's ``send_email``.
    """
    address = email.lower()
    row = await (
        self._db()
        .select(AuthUser.id, AuthUser.email_confirmed_at)
        .from_(AuthUser)
        .where(AuthUser.email == address)
        .first_or_none()
    )
    if not row:
        return

    uid, confirmed_at = row
    if confirmed_at:
        return  # Nothing left to confirm.

    # The new token maps to the user id in KV until it expires.
    token = generate_secure_token()
    lifetime_seconds = self._config.confirmation_token_expire_hours * 3600
    await self._kv().set(
        f"{self._config.cache_prefix}:confirmation:{token}".encode(),
        str(uid).encode(),
        ttl=lifetime_seconds,
    )

    await self._email().send_email(
        subject=confirmation_subject,
        to_email=address,
        template="confirmation.html",
        confirmation_url=f"{confirmation_url}?token={token}",
        **kwargs,
    )
# =========================================================================
# Organizations
# =========================================================================

def _to_org_info(self, org: AuthOrganization) -> OrgInfo:
    """Build the public ``OrgInfo`` view of an organization row.

    The id is rendered as a string and missing metadata becomes an
    empty dict.
    """
    metadata = org.metadata or {}
    return OrgInfo(
        id=str(org.id),
        name=org.name,
        slug=org.slug,
        metadata=metadata,
        created_at=org.created_at,
        updated_at=org.updated_at,
    )

def _to_org_member_info(self, member: AuthOrgMember) -> OrgMemberInfo:
    """Build the public ``OrgMemberInfo`` view of a membership row.

    Both ids are rendered as strings.
    """
    org_id = str(member.org_id)
    user_id = str(member.user_id)
    return OrgMemberInfo(
        org_id=org_id,
        user_id=user_id,
        role=member.role,
        created_at=member.created_at,
        updated_at=member.updated_at,
    )
async def create_org(
    self,
    *,
    name: str,
    slug: str,
    creator_id: str | uuid.UUID,
    **kwargs: Any,
) -> OrgInfo:
    """Create an organization and add its creator as owner.

    The organization row is inserted with slug conflicts ignored; when
    the insert yields no row, the slug is reported as taken. The
    membership row is written in a separate statement.

    NOTE(review): ``**kwargs`` is accepted but not used by this method.

    Raises:
        OrgSlugConflictError: Slug is already taken.
    """
    timestamp = datetime.now(UTC)

    created = await (
        self._db()
        .insert(AuthOrganization)
        .values(name=name, slug=slug, created_at=timestamp, updated_at=timestamp)
        .ignore_conflicts(target=AuthOrganization.slug)
        .returning(AuthOrganization)
        .execute()
    )
    if created is None:
        raise OrgSlugConflictError(slug)

    # The creator becomes the first owner.
    owner_row = self._db().insert(AuthOrgMember).values(
        org_id=created.id,
        user_id=str(creator_id),
        role="owner",
        created_at=timestamp,
        updated_at=timestamp,
    )
    await owner_row.execute()

    return self._to_org_info(created)
async def _resolve_org_id(
    self,
    *,
    org_id: str | uuid.UUID | None,
    slug: str | None,
) -> str | None:
    """Turn an ``org_id`` or a ``slug`` into the organization id string.

    Exactly one of the two must be given. An ``org_id`` is returned
    stringified without a database lookup; a ``slug`` is looked up.

    Returns:
        The id as a string, or ``None`` when ``slug`` was given but no
        organization has that slug.

    Raises:
        ValueError: Neither or both of ``org_id`` and ``slug`` were given.
    """
    has_id = org_id is not None
    has_slug = slug is not None
    if has_id == has_slug:
        raise ValueError("Provide exactly one of org_id= or slug=")

    if not has_slug:
        # Only the id was supplied: nothing to look up.
        return str(org_id)

    by_slug = (
        self._db()
        .select(AuthOrganization)
        .where(AuthOrganization.slug == slug)
    )
    found = await by_slug.first_or_none()
    if not found:
        return None
    return str(found.id)
async def get_org(
    self,
    *,
    org_id: str | uuid.UUID | None = None,
    slug: str | None = None,
) -> OrgInfo:
    """Fetch one organization, identified by ``org_id`` or by ``slug``.

    Exactly one identifier must be supplied.

    Raises:
        OrgNotFoundError: No organization matches the identifier.
        ValueError: Neither or both identifiers were supplied (raised
            by ``_resolve_org_id``).
    """
    resolved = await self._resolve_org_id(org_id=org_id, slug=slug)
    if resolved is None:
        # Only a slug lookup can come back empty.
        raise OrgNotFoundError(f"No org with slug {slug!r}")

    by_id = (
        self._db()
        .select(AuthOrganization)
        .where(AuthOrganization.id == resolved)
    )
    record = await by_id.first_or_none()
    if record is None:
        raise OrgNotFoundError(f"No org with id {resolved!r}")

    return self._to_org_info(record)
async def update_org(
    self,
    *,
    org_id: str | uuid.UUID | None = None,
    org_slug: str | None = None,
    name: str | None = None,
    slug: str | None = None,
    **kwargs: Any,
) -> OrgInfo:
    """Update an organization's name and/or slug.

    Identify the organization by ``org_id`` or ``org_slug`` (exactly
    one). ``slug`` is the new slug to assign -- not the lookup key.
    ``updated_at`` is refreshed even when neither ``name`` nor ``slug``
    is given.

    Args:
        org_id: Id of the organization to update.
        org_slug: Current slug of the organization to update.
        name: New name, or ``None`` to leave it unchanged.
        slug: New slug, or ``None`` to leave it unchanged.
        **kwargs: Accepted but not used by this implementation.

    Returns:
        The organization as stored after the update.

    Raises:
        OrgNotFoundError: Org identifier did not resolve, or the org
            disappeared before the update was applied.
        OrgSlugConflictError: The update failed while a new slug was
            being assigned.
    """
    canonical = await self._resolve_org_id(org_id=org_id, slug=org_slug)
    if canonical is None:
        raise OrgNotFoundError(f"No org with slug {org_slug!r}")

    # Existence is checked outside the ``try`` below so that failures of
    # this lookup are never re-labelled as slug conflicts.
    found = await (
        self._db()
        .select(AuthOrganization)
        .where(AuthOrganization.id == canonical)
        .first_or_none()
    )
    if found is None:
        raise OrgNotFoundError(f"No org with id {canonical!r}")

    updates: dict[str, Any] = {"updated_at": datetime.now(UTC)}
    if name is not None:
        updates["name"] = name
    if slug is not None:
        updates["slug"] = slug

    # Only the statement itself lives in the ``try``: unpacking the
    # result inside it would turn "no row updated" into a bogus
    # OrgSlugConflictError.
    try:
        updated = await (
            self._db()
            .update(AuthOrganization)
            .set(**updates)
            .where(AuthOrganization.id == canonical)
            .returning(AuthOrganization)
            .execute()
        )
    except Exception as e:
        # Slug uniqueness is enforced at the column level; surface a
        # typed conflict only if the caller actually changed the slug.
        # NOTE(review): this still maps *any* failure to a conflict when
        # ``slug`` is set; narrowing it needs the DB layer's
        # unique-violation exception type, which is not visible here.
        if slug is not None:
            raise OrgSlugConflictError(slug) from e
        raise

    rows = list(updated or ())
    if not rows:
        # The org was deleted between the existence check and the UPDATE.
        raise OrgNotFoundError(f"No org with id {canonical!r}")
    return self._to_org_info(rows[0])
async def delete_org(
    self,
    *,
    org_id: str | uuid.UUID | None = None,
    slug: str | None = None,
) -> bool:
    """Delete an organization, identified by ``org_id`` or ``slug``.

    Only the organization row is deleted here.
    NOTE(review): removal of its memberships is presumably left to a
    cascading foreign key -- confirm against the schema.

    Returns:
        ``True`` when a matching organization was found and the delete
        was issued, ``False`` when the identifier matched nothing.
    """
    resolved = await self._resolve_org_id(org_id=org_id, slug=slug)
    if resolved is None:
        return False

    lookup = (
        self._db()
        .select(AuthOrganization)
        .where(AuthOrganization.id == resolved)
    )
    if await lookup.first_or_none() is None:
        return False

    removal = (
        self._db()
        .delete(AuthOrganization)
        .where(AuthOrganization.id == resolved)
    )
    await removal.execute()
    return True
async def list_orgs(
    self,
    *,
    user_id: str | uuid.UUID | None = None,
    limit: int | None = None,
    offset: int | None = None,
) -> list[OrgInfo]:
    """List organizations, newest first.

    Args:
        user_id: When given, only organizations this user is a member
            of are returned.
        limit: Maximum number of rows, or ``None`` for no limit.
        offset: Number of rows to skip, or ``None`` for none.
    """
    query = self._db().select(AuthOrganization)
    query = query.order_by(AuthOrganization.created_at, asc=False)

    if user_id is not None:
        # Restrict to orgs that have a membership row for this user.
        query = query.inner_join(
            AuthOrgMember,
            AuthOrgMember.org_id == AuthOrganization.id,
        )
        query = query.where(AuthOrgMember.user_id == str(user_id))

    if limit is not None:
        query = query.limit(limit)
    if offset is not None:
        query = query.offset(offset)

    rows = await query.execute()
    return [self._to_org_info(row) for row in rows]
# ========================================================================= # Organization Membership # =========================================================================
async def add_org_member(
    self,
    *,
    org_id: str | uuid.UUID | None = None,
    slug: str | None = None,
    user_id: str | uuid.UUID,
    role: str = "member",
) -> OrgMemberInfo:
    """Add a user to an organization (identify by ``org_id`` or ``slug``).

    Args:
        org_id: Id of the target organization.
        slug: Slug of the target organization.
        user_id: User to add.
        role: Role to grant; defaults to ``"member"``.

    Returns:
        The newly created membership.

    Raises:
        OrgNotFoundError: Org identifier did not resolve.
        MemberAlreadyExistsError: User is already a member.
    """
    resolved = await self._resolve_org_id(org_id=org_id, slug=slug)
    if resolved is None:
        raise OrgNotFoundError(f"No org with slug {slug!r}")

    timestamp = datetime.now(UTC)
    # An existing (org, user) pair is skipped by the insert rather than
    # raised, which shows up here as "no row returned".
    insert_member = (
        self._db()
        .insert(AuthOrgMember)
        .values(
            org_id=resolved,
            user_id=str(user_id),
            role=role,
            created_at=timestamp,
            updated_at=timestamp,
        )
        .ignore_conflicts(target=(AuthOrgMember.org_id, AuthOrgMember.user_id))
        .returning(AuthOrgMember)
    )
    created = await insert_member.execute()
    if created is None:
        raise MemberAlreadyExistsError()

    return self._to_org_member_info(created)
async def update_org_member(
    self,
    *,
    org_id: str | uuid.UUID | None = None,
    slug: str | None = None,
    user_id: str | uuid.UUID,
    role: str,
) -> OrgMemberInfo:
    """Change the role of an existing organization member.

    Identify the organization by ``org_id`` or ``slug`` (exactly one).

    Args:
        org_id: Id of the organization.
        slug: Slug of the organization.
        user_id: Member whose role is changed.
        role: New role to store.

    Returns:
        The membership as stored after the update.

    Raises:
        OrgNotFoundError: Org identifier did not resolve.
        OrgMemberNotFoundError: User is not a member of the org, or the
            membership disappeared before the update was applied.
    """
    canonical = await self._resolve_org_id(org_id=org_id, slug=slug)
    if canonical is None:
        raise OrgNotFoundError(f"No org with slug {slug!r}")

    member_id = str(user_id)
    not_a_member = f"User {user_id!r} is not a member of org {canonical!r}"

    existing = await (
        self._db()
        .select(AuthOrgMember)
        .where(AuthOrgMember.org_id == canonical)
        .where(AuthOrgMember.user_id == member_id)
        .first_or_none()
    )
    if existing is None:
        raise OrgMemberNotFoundError(not_a_member)

    # NOTE(review): unlike remove_org_member, nothing here stops the
    # last owner from being demoted -- confirm whether that is intended.
    updated = await (
        self._db()
        .update(AuthOrgMember)
        .set(role=role, updated_at=datetime.now(UTC))
        .where(AuthOrgMember.org_id == canonical)
        .where(AuthOrgMember.user_id == member_id)
        .returning(AuthOrgMember)
        .execute()
    )

    rows = list(updated or ())
    if not rows:
        # The membership was removed between the lookup and the UPDATE;
        # report it as the documented error instead of letting result
        # unpacking fail with an opaque ValueError.
        raise OrgMemberNotFoundError(not_a_member)
    return self._to_org_member_info(rows[0])
async def remove_org_member(
    self,
    *,
    org_id: str | uuid.UUID | None = None,
    slug: str | None = None,
    user_id: str | uuid.UUID,
) -> bool:
    """Remove a user's membership in an organization.

    Identify the organization by ``org_id`` or ``slug`` (exactly one).

    Returns:
        ``True`` when the membership existed and the delete was issued,
        ``False`` if the org or membership does not exist.

    Raises:
        LastOwnerError: The member is an owner and the org has at most
            one owner, so removal would leave it without one.
    """
    resolved = await self._resolve_org_id(org_id=org_id, slug=slug)
    if resolved is None:
        return False

    membership = await (
        self._db()
        .select(AuthOrgMember)
        .where(AuthOrgMember.org_id == resolved)
        .where(AuthOrgMember.user_id == str(user_id))
        .first_or_none()
    )
    if membership is None:
        return False

    if membership.role == "owner":
        # Owners may only leave while another owner remains.
        owners = await (
            self._db()
            .select(AuthOrgMember)
            .where(AuthOrgMember.org_id == resolved)
            .where(AuthOrgMember.role == "owner")
            .count()
        )
        if not owners > 1:
            logger.error(
                "Remove org member failed: cannot remove last owner of org %s",
                resolved,
            )
            raise LastOwnerError(
                f"Cannot remove the last owner of org {resolved!r}"
            )

    removal = (
        self._db()
        .delete(AuthOrgMember)
        .where(AuthOrgMember.org_id == resolved)
        .where(AuthOrgMember.user_id == str(user_id))
    )
    await removal.execute()
    return True
async def list_org_members(
    self,
    *,
    org_id: str | uuid.UUID | None = None,
    slug: str | None = None,
    limit: int | None = None,
    offset: int | None = None,
) -> list[OrgMemberInfo]:
    """List an organization's members, oldest membership first.

    Identify the organization by ``org_id`` or ``slug`` (exactly one).

    Args:
        org_id: Id of the organization.
        slug: Slug of the organization.
        limit: Maximum number of rows, or ``None`` for no limit.
        offset: Number of rows to skip, or ``None`` for none.

    Raises:
        OrgNotFoundError: Org identifier did not resolve.
    """
    resolved = await self._resolve_org_id(org_id=org_id, slug=slug)
    if resolved is None:
        raise OrgNotFoundError(f"No org with slug {slug!r}")

    query = self._db().select(AuthOrgMember)
    query = query.where(AuthOrgMember.org_id == resolved)
    query = query.order_by(AuthOrgMember.created_at, asc=True)

    if limit is not None:
        query = query.limit(limit)
    if offset is not None:
        query = query.offset(offset)

    rows = await query.execute()
    return [self._to_org_member_info(row) for row in rows]
async def get_org_member(
    self,
    *,
    org_id: str | uuid.UUID | None = None,
    slug: str | None = None,
    user_id: str | uuid.UUID,
) -> OrgMemberInfo:
    """Fetch the membership of one user in one organization.

    Identify the organization by ``org_id`` or ``slug`` (exactly one).

    Raises:
        OrgNotFoundError: Org identifier did not resolve.
        OrgMemberNotFoundError: User is not a member of the org.
    """
    resolved = await self._resolve_org_id(org_id=org_id, slug=slug)
    if resolved is None:
        raise OrgNotFoundError(f"No org with slug {slug!r}")

    lookup = (
        self._db()
        .select(AuthOrgMember)
        .where(AuthOrgMember.org_id == resolved)
        .where(AuthOrgMember.user_id == str(user_id))
    )
    membership = await lookup.first_or_none()
    if membership is not None:
        return self._to_org_member_info(membership)

    raise OrgMemberNotFoundError(
        f"User {user_id!r} is not a member of org {resolved!r}"
    )
# ========================================================================= # Organization Session Context # =========================================================================
async def set_active_org(
    self,
    *,
    session_id: str | uuid.UUID,
    org_id: str | uuid.UUID | None,
) -> TokenPair:
    """Switch (or clear) the active organization of a session.

    Pass ``org_id=None`` to clear the active org.

    Returns:
        A token pair created for the session's user, session id and
        token. Its extra claims always carry ``role``; when an org is
        set they also carry ``org_id`` and ``org_role``.

    Raises:
        InvalidTokenError: Session id is unknown or already revoked.
        OrgMemberNotFoundError: User is not a member of the target org.
    """
    # Most recently created non-revoked session row with this id.
    current = await (
        self._db()
        .select(AuthSession)
        .where(AuthSession.session_id == str(session_id))
        .where(~AuthSession.revoked)
        .order_by(AuthSession.created_at, asc=False)
        .first_or_none()
    )
    if current is None:
        logger.error("Set active org failed: session not found")
        raise InvalidTokenError("Session not found or revoked")

    claims: dict[str, Any] = {"role": current.role}
    stored_org_id: str | None = None

    if org_id is not None:
        # Only members may activate an org; their role goes into the token.
        membership = await (
            self._db()
            .select(AuthOrgMember)
            .where(AuthOrgMember.org_id == str(org_id))
            .where(AuthOrgMember.user_id == str(current.user_id))
            .first_or_none()
        )
        if membership is None:
            raise OrgMemberNotFoundError(
                f"User {current.user_id!r} is not a member of org {org_id!r}"
            )
        stored_org_id = str(org_id)
        claims["org_id"] = stored_org_id
        claims["org_role"] = membership.role

    # One statement covers both cases: the new org id, or None to clear.
    await (
        self._db()
        .update(AuthSession)
        .set(org_id=stored_org_id)
        .where(AuthSession.session_id == str(session_id))
        .where(~AuthSession.revoked)
        .execute()
    )

    return create_token_pair(
        self._config.jwt,
        current.user_id,
        current.session_id,
        current.token,
        extra_claims=claims,
    )