"""Column descriptor base class for Derp ORM."""
from __future__ import annotations
import dataclasses
import enum as enum_lib
from typing import Any, Literal, Self, overload
from derp.orm.expression_base import ComparisonOperator, Expression
[docs]
class FK(enum_lib.StrEnum):
    """Actions for foreign key ON DELETE / ON UPDATE clauses.

    Each member's value is the SQL keyword text for the action. Because this
    is a ``StrEnum``, formatting a member into a string yields that text, which
    is how ``Column.foreign_key_sql`` appends it after ``ON DELETE`` /
    ``ON UPDATE``. ``Field`` upper-cases plain strings before looking them up
    here, so ``"set null"`` resolves to ``FK.SET_NULL``.
    """

    # Values are the exact keyword sequences emitted into the generated SQL.
    CASCADE = "CASCADE"
    SET_NULL = "SET NULL"
    SET_DEFAULT = "SET DEFAULT"
    RESTRICT = "RESTRICT"
[docs]
class Fn:
    """Predefined SQL functions for use as column defaults.

    Every helper is a static method that returns the SQL text of the
    expression as a plain string; nothing is executed here.
    """

    @staticmethod
    def gen_random_uuid() -> str:
        """Return the SQL call ``gen_random_uuid()``."""
        return "gen_random_uuid()"

    @staticmethod
    def now() -> str:
        """Return the SQL call ``now()``."""
        return "now()"

    @staticmethod
    def current_timestamp() -> str:
        """Return the SQL keyword ``CURRENT_TIMESTAMP``."""
        return "CURRENT_TIMESTAMP"

    @staticmethod
    def to_tsvector(config: str, *columns: str) -> str:
        """Build a ``to_tsvector(config, col1 || ' ' || col2)`` expression.

        Args:
            config: Text search configuration name; emitted as a quoted SQL
                string literal.
            *columns: One or more column names (or SQL expressions), joined
                with a single-space separator using ``||``.

        Returns:
            The SQL text of the ``to_tsvector`` call.

        Raises:
            ValueError: If no columns are given, since the result would be
                the invalid fragment ``to_tsvector('config', )``.
        """
        if not columns:
            raise ValueError("to_tsvector() requires at least one column.")
        # Double embedded single quotes so the config stays one SQL literal.
        quoted_config = config.replace("'", "''")
        expr = " || ' ' || ".join(columns)
        return f"to_tsvector('{quoted_config}', {expr})"
[docs]
class FieldSpec:
    """Column constraints returned by :func:`Field`.

    Acts as a stand-in on the class body: ``Table.__init_subclass__`` swaps
    it for a real :class:`Column` descriptor once the type annotation has
    been resolved.
    """

    __slots__ = (
        "primary",
        "unique",
        "default",
        "generated",
        "foreign_key",
        "on_delete",
        "on_update",
    )

    def __init__(
        self,
        *,
        primary: bool = False,
        unique: bool = False,
        default: Any = dataclasses.MISSING,
        generated: str | None = None,
        foreign_key: str | Column[Any] | None = None,
        on_delete: FK | None = None,
        on_update: FK | None = None,
    ) -> None:
        """Store the declared constraints.

        Raises:
            ValueError: If both ``default`` and ``generated`` are supplied.
        """
        # ``dataclasses.MISSING`` is the "no default declared" sentinel, so an
        # explicit ``default=None`` still counts as a default here.
        default_given = default is not dataclasses.MISSING
        if default_given and generated is not None:
            raise ValueError("A column cannot have both `default` and `generated`.")

        self.primary, self.unique = primary, unique
        self.default, self.generated = default, generated
        self.foreign_key = foreign_key
        self.on_delete, self.on_update = on_delete, on_update
[docs]
def Field(
    *,
    primary: bool = False,
    unique: bool = False,
    default: Any = dataclasses.MISSING,
    generated: str | None = None,
    foreign_key: str | Column[Any] | None = None,
    on_delete: FK
    | Literal["cascade", "set null", "set default", "restrict"]
    | None = None,
    on_update: FK
    | Literal["cascade", "set null", "set default", "restrict"]
    | None = None,
) -> Any:
    """Declare column constraints.

    String actions for ``on_delete`` / ``on_update`` are upper-cased and
    looked up in :class:`FK`; an unknown action therefore raises the enum's
    ``ValueError``. All other arguments are passed to :class:`FieldSpec`
    unchanged.

    Foreign keys::

        Field(foreign_key=User.id, on_delete="cascade")
        Field(foreign_key="users.id")

    Generated columns::

        Field(generated="price * quantity")
    """
    # Normalise the referential actions first (delete, then update) so that
    # FieldSpec only ever receives FK members or None.
    delete_action = on_delete
    if isinstance(delete_action, str):
        delete_action = FK(delete_action.upper())

    update_action = on_update
    if isinstance(update_action, str):
        update_action = FK(update_action.upper())

    return FieldSpec(
        primary=primary,
        unique=unique,
        default=default,
        generated=generated,
        foreign_key=foreign_key,
        on_delete=delete_action,
        on_update=update_action,
    )
[docs]
class Column[T](Expression):
"""Base descriptor for all table columns.
Extends Expression so columns can be used directly in query building.
Implements the descriptor protocol for typed class/instance access.
Subclasses set ``_sql_type`` as a class variable. Parameterized types
(e.g., ``Varchar[255]``) override ``sql_type()`` to include parameters.
"""
_sql_type: str = ""
_primary: bool
_unique: bool
_nullable: bool
_default: Any
_generated: str | None
_foreign_key: str | Column[Any] | None
_on_delete: FK | None
_on_update: FK | None
_table_name: str | None
_field_name: str | None
[docs]
def __init__(self, spec: FieldSpec) -> None:
# Use object.__setattr__ to bypass the Column descriptor's
# __set__ which ty treats as governing all _-prefixed attrs.
object.__setattr__(self, "_primary", spec.primary)
object.__setattr__(self, "_unique", spec.unique)
object.__setattr__(self, "_nullable", False)
object.__setattr__(self, "_default", spec.default)
object.__setattr__(self, "_generated", spec.generated)
object.__setattr__(self, "_foreign_key", spec.foreign_key)
object.__setattr__(self, "_on_delete", spec.on_delete)
object.__setattr__(self, "_on_update", spec.on_update)
object.__setattr__(self, "_table_name", None)
object.__setattr__(self, "_field_name", None)
# -- Descriptor protocol --------------------------------------------------
@overload
def __get__(self, obj: None, owner: type) -> Self: ...
@overload
def __get__(self, obj: object, owner: type) -> T: ...
def __get__(self, obj: object | None, owner: type) -> Self | T:
if obj is None:
return self
return getattr(obj, f"_{self._field_name}")
def __set__(self, obj: object, value: T) -> None:
setattr(obj, f"_{self._field_name}", value)
def __set_name__(self, owner: Any, name: str) -> None:
self._field_name = name
# Table name is set later by Table.__init_subclass__
# -- Metadata accessors ---------------------------------------------------
@property
def primary_key(self) -> bool:
return self._primary
@property
def unique(self) -> bool:
return self._unique
@property
def nullable(self) -> bool:
return self._nullable
@property
def default(self) -> Any:
return self._default if self._default is not dataclasses.MISSING else None
@property
def has_default(self) -> bool:
return self._default is not dataclasses.MISSING
@property
def generated(self) -> str | None:
return self._generated
@property
def foreign_key(self) -> str | Column[Any] | None:
return self._foreign_key
@property
def on_delete(self) -> FK | None:
return self._on_delete
@property
def on_update(self) -> FK | None:
return self._on_update
[docs]
def foreign_key_sql(self) -> str | None:
"""Generate the REFERENCES clause, or None if no FK."""
if self._foreign_key is None:
return None
if isinstance(self._foreign_key, Column):
col: Column[Any] = self._foreign_key
if not col._table_name or not col._field_name:
raise ValueError(
"Column passed to foreign_key has no table metadata. "
"Use a class-level column reference like User.id."
)
ref = f"{col._table_name}({col._field_name})"
else:
ref = self._foreign_key.replace(".", "(") + ")"
sql = f"REFERENCES {ref}"
if self._on_delete:
sql += f" ON DELETE {self._on_delete}"
if self._on_update:
sql += f" ON UPDATE {self._on_update}"
return sql
[docs]
def sql_type(self) -> str:
return self._sql_type
[docs]
def is_auto_increment(self) -> bool:
return self._sql_type in ("SERIAL", "BIGSERIAL")
# -- Expression interface -------------------------------------------------
[docs]
def to_sql(self, params: list[Any]) -> str:
if self._table_name and self._field_name:
return f"{self._table_name}.{self._field_name}"
if self._field_name:
return self._field_name
raise ValueError("Column missing table/field name metadata")
# -- Comparison operators (supplement Expression's dunders) ----------------
def __invert__(self) -> Any:
"""Bitwise NOT (~col) — produces ``col = FALSE`` for boolean columns."""
from derp.orm.query.expressions import BinaryOp, to_expr
return BinaryOp(self, ComparisonOperator.EQ, to_expr(False))
# -- Aggregate methods (not on Expression base) ---------------------------
[docs]
def count(self) -> Any:
from derp.orm.query.expressions import AggregateFunc
return AggregateFunc("COUNT", self)
[docs]
def sum(self) -> Any:
from derp.orm.query.expressions import AggregateFunc
return AggregateFunc("SUM", self)
[docs]
def avg(self) -> Any:
from derp.orm.query.expressions import AggregateFunc
return AggregateFunc("AVG", self)
[docs]
def min(self) -> Any:
from derp.orm.query.expressions import AggregateFunc
return AggregateFunc("MIN", self)
[docs]
def max(self) -> Any:
from derp.orm.query.expressions import AggregateFunc
return AggregateFunc("MAX", self)
[docs]
def case(self, mapping: dict[Any, Any], *, else_: Any | None = None) -> Any:
from derp.orm.query.expressions import CaseExpression
return CaseExpression(self, list(mapping.items()), else_value=else_)