# Source code for derp.orm.table

"""Table base class for Derp ORM using Column descriptors."""

from __future__ import annotations

import copy
import enum as enum_lib
import json
import types as pytypes
from collections.abc import Sequence
from typing import Any, ClassVar, Self, dataclass_transform, get_args, get_origin

from derp.orm.column.base import Column, Field, FieldSpec
from derp.orm.index import Index


def _unwrap_nullable(ann: Any) -> tuple[Any, bool]:
    """Unwrap ``SomeType | None`` → ``(SomeType, True)``."""
    origin = get_origin(ann)
    if origin is pytypes.UnionType:
        args = [a for a in get_args(ann) if a is not type(None)]
        if len(args) == 1:
            return args[0], True
    return ann, False


def _resolve_sql_type(col_type: type[Column[Any]]) -> str:
    """Get the SQL type string from a PG type class, handling enums."""
    sql = getattr(col_type, "_sql_type", "")
    if sql:
        return sql

    # Check if this is Column[SomeEnum] — derive from enum class
    type_args = get_args(col_type)
    if type_args:
        arg = type_args[0]
        if isinstance(arg, type) and issubclass(arg, enum_lib.Enum):
            from derp.orm.column.types import _enum_sql_name

            return _enum_sql_name(arg)

    return ""


@dataclass_transform(kw_only_default=True, field_specifiers=(Field,))
class Table:
    """Base class for all Derp table definitions.

    Subclassing turns every ``Field()`` attribute into a ``Column``
    descriptor bound to the table, and records table-level metadata
    (``__table_name__``, ``__columns__``, resolved indexes).

    Example::

        class User(Table, table="users"):
            id: Serial = Field(primary=True)
            name: Text = Field()
            email: Varchar[255] = Field(unique=True)

        # Query building — direct class access:
        db.select(User).where(User.name == "Alice")
    """

    __table_name__: ClassVar[str]
    __explicit_table__: ClassVar[bool]
    __columns__: ClassVar[dict[str, Column[Any]]]
    # Column name -> instance attribute name ("_<name>"), used by the
    # fast hydration (_from_row) and serialization (to_dict) paths.
    __slot_map__: ClassVar[dict[str, str]]
    _resolved_indexes: ClassVar[list[Index]]

    @classmethod
    def indexes(cls) -> Sequence[Index]:
        """Override to define indexes for this table."""
        return []

    def __init_subclass__(cls, table: str | None = None, **kwargs: Any) -> None:
        """Build column descriptors and table metadata for a new subclass.

        Args:
            table: Explicit SQL table name.  When omitted, the name is
                inherited from an explicitly-named ancestor, or otherwise
                derived from this class's lowercased name.

        Raises:
            TypeError: if *table* conflicts with an explicitly-named
                ancestor, if a ``Field()`` has no resolvable annotation, or
                if nullable annotations are inconsistent.
        """
        super().__init_subclass__(**kwargs)

        # --- Table name -------------------------------------------------
        if table is not None:
            cls.__table_name__ = table
            cls.__explicit_table__ = True
            # Enforce: if parent has __explicit_table__, child must use same
            for base in cls.__mro__[1:]:
                parent_table = getattr(base, "__table_name__", None)
                parent_explicit = getattr(base, "__explicit_table__", False)
                if (
                    parent_explicit
                    and parent_table is not None
                    and table != parent_table
                ):
                    raise TypeError(
                        f"Table '{cls.__name__}' uses table name '{table}' "
                        f"but its parent '{base.__name__}' uses "
                        f"'{parent_table}'. Inherited tables must use the "
                        f"same table name as their parent."
                    )
        elif not getattr(cls, "__explicit_table__", False):
            # No explicitly-named ancestor pins the name, so derive it from
            # this class instead of inheriting a parent's auto-derived name.
            cls.__table_name__ = cls.__name__.lower()
            cls.__explicit_table__ = False
        table_name = cls.__table_name__

        # Resolve type annotations (evaluates "Varchar[255]" strings etc.)
        hints = _get_type_hints_safe(cls)

        # Snapshot the names defined directly on this class *before* the
        # loop below attaches inherited clones with setattr; otherwise the
        # first (farthest) ancestor's clone would block nearer overrides.
        own_names = set(cls.__dict__)

        columns: dict[str, Column[Any]] = {}

        # Inherited columns, cloned and rebound to this class's table name.
        # Bases are walked farthest-first, so a column redefined by a base
        # nearer in the MRO overwrites the ancestor's version.
        for base in reversed(cls.__mro__[1:]):
            base_columns = getattr(base, "__columns__", None)
            if base_columns is None:
                continue
            for name, col in base_columns.items():
                if name in own_names:
                    continue
                clone = copy.copy(col)
                clone._table_name = table_name
                clone._field_name = name
                setattr(cls, name, clone)
                columns[name] = clone

        # Own columns: FieldSpec in class dict -> resolve annotation -> Column
        for name in list(cls.__dict__):
            attr = cls.__dict__[name]
            if not isinstance(attr, FieldSpec):
                continue
            ann_type = hints.get(name)
            if ann_type is None:
                raise TypeError(
                    f"{cls.__name__}.{name}: has Field() but no type annotation"
                )
            # Nullable[X] sets the column to nullable
            is_nullable = getattr(ann_type, "_nullable_marker", False)
            # Construct Column from PG type class + FieldSpec
            if isinstance(ann_type, type) and issubclass(ann_type, Column):
                col = ann_type(attr)
            else:
                # Fallback: bare Column with resolved SQL type
                col = Column(attr)
                col._sql_type = _resolve_sql_type(ann_type) or ""
            if is_nullable:
                col._nullable = True
                if not col.has_default:
                    col._default = None
            col._table_name = table_name
            col._field_name = name
            setattr(cls, name, col)
            columns[name] = col

        cls.__columns__ = columns
        # Precompute col_name -> "_col_name" for fast hydration
        cls.__slot_map__ = {name: f"_{name}" for name in columns}
        # Validate nullable annotations
        cls._validate_nullable_annotations(hints)
        # Resolve indexes from the indexes() classmethod.
        cls._resolved_indexes = list(cls.indexes())

    def __init__(self, **kwargs: Any) -> None:
        """Create an instance from keyword arguments, one per column.

        Omitted columns fall back to their default, then to ``None`` when
        the column is nullable.

        Raises:
            TypeError: on an unknown keyword argument, or when a column with
                no default that is not nullable is missing.
        """
        columns = type(self).__columns__
        for name, value in kwargs.items():
            if name not in columns:
                raise TypeError(
                    f"{type(self).__name__}() got an unexpected keyword "
                    f"argument '{name}'"
                )
            setattr(self, name, value)

        # Handle defaults for missing fields
        missing: list[str] = []
        for name, col in columns.items():
            if name in kwargs:
                continue
            if col.has_default:
                setattr(self, name, col.default)
            elif col.nullable:
                setattr(self, name, None)
            else:
                missing.append(name)
        if missing:
            raise TypeError(
                f"{type(self).__name__}() missing required keyword "
                f"arguments: {', '.join(repr(n) for n in missing)}"
            )

    @classmethod
    def _from_row(cls, data: dict[str, Any] | Any) -> Self:
        """Fast-path hydration from a database row.

        Bypasses ``__init__`` validation and descriptor ``__set__``.
        Uses the precomputed slot map and ``object.__setattr__`` directly.
        *data* can be a ``dict`` or an ``asyncpg.Record``: both support
        ``key in data`` and ``data[key]``.  Columns absent from *data* are
        left unset on the instance.
        """
        obj = object.__new__(cls)
        sa = object.__setattr__
        slot_map = cls.__slot_map__
        for col_name, attr_name in slot_map.items():
            if col_name in data:
                sa(obj, attr_name, data[col_name])
        return obj

    @classmethod
    def _validate_nullable_annotations(
        cls, hints: dict[str, Any] | None = None
    ) -> None:
        """Check that nullable columns are annotated with ``Nullable[X]``.

        ``X | None`` annotations are rejected in favour of ``Nullable[X]``.
        When *hints* is omitted they are re-resolved; if that fails,
        validation is skipped (best effort).

        Raises:
            TypeError: on an ``X | None`` annotation, or on a nullable
                column whose annotation is not ``Nullable[...]``.
        """
        if hints is None:
            try:
                hints = _get_type_hints_safe(cls)
            except Exception:
                return
        for name, col in cls.__columns__.items():
            if name not in hints:
                continue
            ann = hints[name]
            is_nullable_ann = getattr(ann, "_nullable_marker", False)
            # Reject | None syntax — must use Nullable[X]
            _, has_union_none = _unwrap_nullable(ann)
            if has_union_none:
                raise TypeError(
                    f"{cls.__name__}.{name}: use Nullable[X] instead "
                    f"of 'X | None' for nullable columns"
                )
            if col.nullable and not is_nullable_ann:
                ann_str = getattr(ann, "__name__", None) or str(ann)
                raise TypeError(
                    f"{cls.__name__}.{name}: column is nullable but "
                    f"annotation {ann_str!r} is not Nullable[...]"
                )

    @classmethod
    def get_table_name(cls) -> str:
        """Get the SQL table name."""
        return cls.__table_name__

    @classmethod
    def get_columns(cls) -> dict[str, Column[Any]]:
        """Get all column definitions (empty for the bare ``Table`` base)."""
        return getattr(cls, "__columns__", {})

    @classmethod
    def get_primary_key(cls) -> tuple[str, Column[Any]] | None:
        """Get the first primary key column as ``(name, column)``, if any."""
        for name, col in cls.get_columns().items():
            if col.primary_key:
                return (name, col)
        return None

    def to_dict(self) -> dict[str, Any]:
        """Serialize instance to a dict keyed by column name."""
        ga = object.__getattribute__
        result: dict[str, Any] = {}
        for name, attr_name in type(self).__slot_map__.items():
            result[name] = ga(self, attr_name)
        return result

    def to_json(self) -> str:
        """Serialize instance to a JSON string."""
        return json.dumps(self.to_dict(), default=_json_default)

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> Self:
        """Construct an instance from a dict (ignores unknown keys)."""
        columns = cls.__columns__
        filtered = {k: v for k, v in data.items() if k in columns}
        return cls(**filtered)

    @classmethod
    def from_json(cls, data: str | bytes) -> Self:
        """Construct an instance from a JSON string."""
        return cls.from_dict(json.loads(data))

    @classmethod
    def to_ddl(cls) -> str:
        """Generate the CREATE TABLE DDL statement, followed by index DDL.

        Raises:
            ValueError: if the table has no columns.
        """
        columns = cls.get_columns()
        if not columns:
            raise ValueError(f"Table {cls.__name__} has no columns defined")
        table_name = cls.get_table_name()
        column_defs: list[str] = []
        constraints: list[str] = []
        index_stmts: list[str] = []

        for col_name, col in columns.items():
            col_def = f"  {col_name} {col.sql_type()}"
            if col.primary_key:
                col_def += " PRIMARY KEY"
            if not col.nullable and not col.primary_key:
                col_def += " NOT NULL"
            if col.unique and not col.primary_key:
                col_def += " UNIQUE"
            if col.generated is not None:
                col_def += f" GENERATED ALWAYS AS ({col.generated}) STORED"
            elif col.default is not None:
                default = col.default
                # Enum members render as their underlying value; str() on a
                # plain Enum would emit "Color.RED", which is not a label.
                if isinstance(default, enum_lib.Enum):
                    default = default.value
                if isinstance(default, str) and (
                    default.endswith("()")
                    or default.upper() in ("CURRENT_TIMESTAMP", "TRUE", "FALSE")
                ):
                    # SQL function call or keyword: emit unquoted.
                    col_def += f" DEFAULT {default}"
                elif isinstance(default, bool):
                    # Checked before int because bool is an int subclass.
                    col_def += f" DEFAULT {str(default).upper()}"
                elif isinstance(default, (int, float)):
                    col_def += f" DEFAULT {default}"
                else:
                    escaped = str(default).replace("'", "''")
                    col_def += f" DEFAULT '{escaped}'"
            column_defs.append(col_def)

            # Foreign key constraints
            fk_sql = col.foreign_key_sql()
            if fk_sql:
                constraints.append(f"  FOREIGN KEY ({col_name}) {fk_sql}")

        # Indexes
        for idx in cls._resolved_indexes:
            index_stmts.append(idx.to_ddl(table_name) + ";")

        all_defs = column_defs + constraints
        ddl = f"CREATE TABLE {table_name} (\n"
        ddl += ",\n".join(all_defs)
        ddl += "\n);"
        if index_stmts:
            ddl += "\n\n" + "\n\n".join(index_stmts)
        return ddl
def get_column_ref(table: type[Table], column_name: str) -> Column[Any]: """Get a column reference for query building.""" columns = table.get_columns() if column_name not in columns: raise ValueError(f"Column {column_name} not found in table {table.__name__}") return columns[column_name] def _json_default(obj: Any) -> Any: """JSON serializer for types not natively supported.""" import datetime import uuid if isinstance(obj, uuid.UUID): return obj.hex if isinstance(obj, datetime.datetime | datetime.date | datetime.time): return obj.isoformat() if isinstance(obj, datetime.timedelta): return obj.total_seconds() if hasattr(obj, "value"): # Enum return obj.value raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable") def _get_type_hints_safe(cls: type) -> dict[str, Any]: """Get type hints for column fields only. Instead of ``typing.get_type_hints`` (which tries to evaluate ALL annotations including ClassVar), we manually evaluate only the annotations that have a corresponding ``FieldSpec`` in the class dict. """ import sys # Build namespace for eval ns: dict[str, Any] = {} # Include typing module for Union, ClassVar etc. import typing as _typing ns.update(vars(_typing)) # Include the column types module from derp.orm.column import types as _col_types ns.update(vars(_col_types)) # Include the column base module (for Column itself) from derp.orm.column import base as _col_base ns.update(vars(_col_base)) # Include the module where the class is defined module = sys.modules.get(cls.__module__) if module is not None: ns.update(vars(module)) result: dict[str, Any] = {} # Walk MRO to collect annotations for klass in reversed(cls.__mro__): for name, ann in getattr(klass, "__annotations__", {}).items(): # Only resolve annotations that have a FieldSpec if isinstance(cls.__dict__.get(name), FieldSpec): if isinstance(ann, str): try: result[name] = eval(ann, ns) # noqa: S307 except Exception: pass else: result[name] = ann return result