Source code for derp.auth.jwt

"""JWT token creation and validation."""

from __future__ import annotations

import uuid
from dataclasses import dataclass
from datetime import UTC, datetime, timedelta
from typing import Any

import jwt

from derp.config import JWTConfig


[docs] @dataclass(kw_only=True) class TokenPayload: """Payload data from a decoded JWT token.""" sub: str # User ID session_id: str exp: datetime iat: datetime iss: str | None = None aud: str | None = None extra: dict[str, Any] | None = None
[docs] @dataclass(kw_only=True) class TokenPair: """Access and refresh token pair.""" access_token: str refresh_token: str token_type: str = "bearer" expires_in: int = 0 # Access token expiry in seconds expires_at: datetime
def create_access_token( config: JWTConfig, user_id: str | uuid.UUID, session_id: str | uuid.UUID, extra_claims: dict[str, Any] | None = None, ) -> str: """Create a short-lived JWT access token. Args: user_id: The user's unique identifier session_id: The session's unique identifier extra_claims: Additional claims to include in the token Returns: Encoded JWT token string """ now = datetime.now(UTC) expires = now + timedelta(minutes=config.access_token_expire_minutes) payload: dict[str, Any] = { "sub": str(user_id), "session_id": str(session_id), "iat": now, "exp": expires, } if config.issuer: payload["iss"] = config.issuer if config.audience: payload["aud"] = config.audience if extra_claims: reserved = {"sub", "session_id", "iat", "exp", "iss", "aud"} conflicts = reserved & extra_claims.keys() if conflicts: raise ValueError( "extra_claims cannot override reserved JWT claims: " f"{', '.join(sorted(conflicts))}" ) payload.update(extra_claims) return jwt.encode( payload, config.secret, algorithm=config.algorithm, ) def decode_token(config: JWTConfig, token: str) -> TokenPayload | None: """Decode and validate a JWT token. Args: token: The JWT token string Returns: TokenPayload with decoded data, or None if the token is expired or invalid. """ try: payload = jwt.decode( token, config.secret, algorithms=[config.algorithm], audience=config.audience, issuer=config.issuer, options={"require": ["aud"]} if config.audience else None, ) except jwt.InvalidTokenError: return None # Extract known fields sub = payload.get("sub") session_id = payload.get("session_id") exp = payload.get("exp") iat = payload.get("iat") if not sub or not session_id: return None # Convert timestamps exp_dt = datetime.fromtimestamp(exp, tz=UTC) if exp else datetime.now(UTC) iat_dt = datetime.fromtimestamp(iat, tz=UTC) if iat else datetime.now(UTC) # Extract extra claims known_keys = {"sub", "session_id", "exp", "iat", "iss", "aud"} extra = {k: v for k, v in payload.items() if k not in known_keys} iss = payload.get("iss") aud = payload.get("aud") return TokenPayload( sub=sub, session_id=session_id, exp=exp_dt, iat=iat_dt, iss=iss, aud=aud, extra=extra if extra else None, ) def create_token_pair( config: JWTConfig, user_id: str | uuid.UUID, session_id: str | uuid.UUID, refresh_token: str, extra_claims: dict[str, Any] | None = None, ) -> TokenPair: """Create a token pair with access token and refresh token. Args: user_id: The user's unique identifier session_id: The session's unique identifier refresh_token: The refresh token string (created separately) extra_claims: Additional claims for the access token Returns: TokenPair with both tokens """ access_token = create_access_token(config, user_id, session_id, extra_claims) expires_in = config.access_token_expire_minutes * 60 expires_at = datetime.now(UTC) + timedelta( minutes=config.access_token_expire_minutes ) return TokenPair( access_token=access_token, refresh_token=refresh_token, token_type="bearer", expires_in=expires_in, expires_at=expires_at, )