"""Query builder for select, insert, update, delete operations."""
from __future__ import annotations
import hashlib
import json
from collections.abc import AsyncIterator, Sequence
from contextlib import asynccontextmanager
from dataclasses import dataclass
from enum import StrEnum
from typing import TYPE_CHECKING, Any, Self, overload
import asyncpg
from derp.kv.base import KVClient
from derp.orm.column.base import Column
from derp.orm.index import SortOrder
from derp.orm.query.expressions import (
ColumnRef,
ExistsExpr,
Expression,
RawSQL,
SubqueryExpr,
_renumber_params,
)
from derp.orm.router import ReplicaRouter
from derp.orm.table import Table
if TYPE_CHECKING:
from derp.orm.query.returning import (
RMT2,
RMT3,
RMT4,
RMT5,
RMT6,
RMT7,
RMT8,
RMT9,
RMT10,
ROT2,
ROT3,
ROT4,
ROT5,
ROT6,
ROT7,
ROT8,
ROT9,
ROT10,
ROTO2,
ROTO3,
ROTO4,
ROTO5,
ROTO6,
ROTO7,
ROTO8,
ROTO9,
ROTO10,
)
@asynccontextmanager
async def _acquire(
    pool_or_conn: asyncpg.Pool | asyncpg.Connection,
) -> AsyncIterator[asyncpg.Connection]:
    """Yield a connection for either a pool or an existing connection.

    When given an ``asyncpg.Pool`` a connection is acquired for the duration
    of the ``async with`` block. Any other object is yielded unchanged, so
    its lifetime stays with the caller.
    """
    if not isinstance(pool_or_conn, asyncpg.Pool):
        yield pool_or_conn
        return
    async with pool_or_conn.acquire() as connection:
        yield connection
# Upper bound on bind parameters per statement, used to size bulk-insert
# chunks. PostgreSQL supports max 65535 parameters per query; this constant
# is that limit itself, so no additional safety margin is subtracted here.
# NOTE(review): asyncpg is believed to cap query arguments at 32767 — confirm
# before relying on chunks this large.
_PG_MAX_PARAMS = 65535
[docs]
class JoinType(StrEnum):
"""SQL JOIN types."""
INNER = "INNER"
LEFT = "LEFT"
RIGHT = "RIGHT"
FULL = "FULL OUTER"
CROSS = "CROSS"
@dataclass
class JoinClause:
    """One JOIN entry of a SELECT query.

    ``condition`` is the ON expression; it is ``None`` for joins that take
    no ON clause (CROSS JOIN).
    """

    # Kind of join; its value is the keyword text placed before ``JOIN``.
    join_type: JoinType
    # Table class being joined.
    table: type[Table]
    # ON expression, or None when the join has no condition.
    condition: Expression | None
@dataclass
class OrderByClause:
    """One ORDER BY entry: a column (or raw column string) plus direction."""

    # Column descriptor, or a string that is rendered into the SQL verbatim.
    column: Column[Any] | str
    # Sort direction; ascending unless stated otherwise.
    direction: SortOrder = SortOrder.ASC
[docs]
class LockMode(StrEnum):
"""SQL row-level locking modes."""
UPDATE = "FOR UPDATE"
NO_KEY_UPDATE = "FOR NO KEY UPDATE"
SHARE = "FOR SHARE"
KEY_SHARE = "FOR KEY SHARE"
@dataclass
class LockClause:
    """Row-level lock request attached to a SELECT query."""

    # Which locking clause to emit.
    mode: LockMode
    # Append NOWAIT to the locking clause.
    nowait: bool = False
    # Append SKIP LOCKED to the locking clause.
    skip_locked: bool = False
@dataclass
class OnConflictClause:
    """ON CONFLICT specification used to turn an INSERT into an upsert."""

    # Column names that form the conflict target.
    target: tuple[str, ...]
    # Either "nothing" (DO NOTHING) or "update" (DO UPDATE SET ...).
    action: str
    # Column -> value assignments for the "update" action.
    set_values: dict[str, Any] | None = None
class _WhereShorthandMixin:
    """Shorthand filter methods that accept string column names or Column.

    Every shorthand resolves its column argument with ``_resolve_column`` and
    forwards the resulting expression to ``where()``, which the concrete
    query class must implement.
    """

    def where(self, cond: Expression) -> Self:
        """Add WHERE clause. Implemented by subclasses."""
        raise NotImplementedError

    def _context_table_name(self) -> str | None:
        """Return the table name unqualified column strings belong to.

        ``_from_table`` takes precedence over ``_table``; either may be a
        plain string or an object exposing ``get_table_name()``. ``None`` is
        returned when neither yields a name.
        """
        context = getattr(self, "_from_table", None)
        if context is None:
            context = getattr(self, "_table", None)
        if context is None:
            return None
        if isinstance(context, str):
            return context
        # A context without get_table_name() (presumably a subquery FROM —
        # TODO confirm) has no name to offer; the caller then raises the
        # regular "no table context" error instead of an AttributeError.
        name_getter = getattr(context, "get_table_name", None)
        return name_getter() if callable(name_getter) else None

    def _resolve_column(self, column: Column[Any] | str) -> Column[Any] | ColumnRef:
        """Resolve a column reference from a string or Column.

        Args:
            column: A ``Column`` (returned unchanged), a qualified
                ``"table.column"`` string, or a bare column name that is
                qualified with the query's table context.

        Raises:
            ValueError: For a bare name when no table context is available.
        """
        if isinstance(column, Column):
            return column
        # Qualified names carry their own table, so the query's table context
        # is not consulted for them at all.
        if "." in column:
            table_part, column_part = column.split(".", 1)
            return ColumnRef(table_part, column_part)
        table_name = self._context_table_name()
        if table_name:
            return ColumnRef(table_name, column)
        raise ValueError(
            f"Cannot resolve column '{column}': no table context. "
            "Use 'table.column' format or set a FROM table."
        )

    def not_(self, column: Column[Any] | str) -> Self:
        """WHERE with the column negated through its ``~`` operator."""
        resolved = self._resolve_column(column)
        return self.where(~resolved)

    def eq(self, column: Column[Any] | str, value: Any) -> Self:
        """WHERE column = value."""
        resolved = self._resolve_column(column)
        return self.where(resolved == value)

    def neq(self, column: Column[Any] | str, value: Any) -> Self:
        """WHERE column <> value."""
        resolved = self._resolve_column(column)
        return self.where(resolved != value)

    def gt(self, column: Column[Any] | str, value: Any) -> Self:
        """WHERE column > value."""
        resolved = self._resolve_column(column)
        return self.where(resolved > value)

    def gte(self, column: Column[Any] | str, value: Any) -> Self:
        """WHERE column >= value."""
        resolved = self._resolve_column(column)
        return self.where(resolved >= value)

    def lt(self, column: Column[Any] | str, value: Any) -> Self:
        """WHERE column < value."""
        resolved = self._resolve_column(column)
        return self.where(resolved < value)

    def lte(self, column: Column[Any] | str, value: Any) -> Self:
        """WHERE column <= value."""
        resolved = self._resolve_column(column)
        return self.where(resolved <= value)

    def is_null(self, column: Column[Any] | str) -> Self:
        """WHERE column IS NULL."""
        resolved = self._resolve_column(column)
        return self.where(resolved.is_null())

    def is_not_null(self, column: Column[Any] | str) -> Self:
        """WHERE column IS NOT NULL."""
        resolved = self._resolve_column(column)
        return self.where(resolved.is_not_null())

    def in_(self, column: Column[Any] | str, values: Sequence[Any]) -> Self:
        """WHERE column IN (values)."""
        resolved = self._resolve_column(column)
        return self.where(resolved.in_(values))

    def not_in(self, column: Column[Any] | str, values: Sequence[Any]) -> Self:
        """WHERE column NOT IN (values)."""
        resolved = self._resolve_column(column)
        return self.where(resolved.not_in(values))

    def like(self, column: Column[Any] | str, pattern: str) -> Self:
        """WHERE column LIKE pattern."""
        resolved = self._resolve_column(column)
        return self.where(resolved.like(pattern))

    def ilike(self, column: Column[Any] | str, pattern: str) -> Self:
        """WHERE column ILIKE pattern."""
        resolved = self._resolve_column(column)
        return self.where(resolved.ilike(pattern))

    def between(self, column: Column[Any] | str, low: Any, high: Any) -> Self:
        """WHERE column BETWEEN low AND high."""
        resolved = self._resolve_column(column)
        return self.where(resolved.between(low, high))
# =============================================================================
# SELECT Query
# =============================================================================
[docs]
class SelectQuery[T](_WhereShorthandMixin):
"""SELECT query - T is the result element type (Table subclass or dict)."""
[docs]
def __init__(
self,
pool: asyncpg.Pool | asyncpg.Connection | None,
columns: tuple[type[Table] | Column[Any] | Expression, ...],
*,
cache_store: KVClient | None = None,
router: ReplicaRouter | None = None,
):
self._pool = pool
self._columns = columns
self._from_table: type[Table] | str | SubqueryExpr | None = None
self._ctes: list[tuple[str, SelectQuery[Any]]] = []
self._joins: list[JoinClause] = []
self._where_clause: Expression | None = None
self._order_by: list[OrderByClause] = []
self._limit_value: int | None = None
self._offset_value: int | None = None
self._group_by: list[Column[Any] | str] = []
self._having_clause: Expression | None = None
self._distinct: bool = False
self._distinct_on: list[Column[Any]] = []
self._lock: LockClause | None = None
self._cache_store: KVClient | None = cache_store
self._cache_ttl: float | None = None
self._cache_lock_ttl: float | None = None
self._cache_retry_delay: float | None = None
self._router: ReplicaRouter | None = router
self._force_primary: bool = False
# Infer from table if first column is a Table class
if columns and isinstance(columns[0], type) and issubclass(columns[0], Table):
self._from_table = columns[0]
[docs]
def from_(self, table: type[Table] | str | SubqueryExpr) -> Self:
"""Set the FROM table. Accepts a Table class, string, or subquery."""
self._from_table = table
return self
[docs]
def where(self, cond: Expression) -> Self:
"""Add WHERE clause. Multiple calls combine with AND."""
if self._where_clause is not None:
self._where_clause = self._where_clause & cond
else:
self._where_clause = cond
return self
[docs]
def inner_join(self, table: type[Table], condition: Expression) -> Self:
"""Add INNER JOIN."""
self._joins.append(JoinClause(JoinType.INNER, table, condition))
return self
[docs]
def left_join(self, table: type[Table], condition: Expression) -> Self:
"""Add LEFT JOIN."""
self._joins.append(JoinClause(JoinType.LEFT, table, condition))
return self
[docs]
def right_join(self, table: type[Table], condition: Expression) -> Self:
"""Add RIGHT JOIN."""
self._joins.append(JoinClause(JoinType.RIGHT, table, condition))
return self
[docs]
def full_join(self, table: type[Table], condition: Expression) -> Self:
"""Add FULL OUTER JOIN."""
self._joins.append(JoinClause(JoinType.FULL, table, condition))
return self
[docs]
def cross_join(self, table: type[Table]) -> Self:
"""Add CROSS JOIN."""
self._joins.append(JoinClause(JoinType.CROSS, table, None))
return self
[docs]
def order_by(self, column: Column[Any] | str, *, asc: bool = True) -> Self:
"""Add ORDER BY clause."""
self._order_by.append(
OrderByClause(column, SortOrder.ASC if asc else SortOrder.DESC)
)
return self
[docs]
def limit(self, n: int) -> Self:
"""Add LIMIT clause."""
self._limit_value = n
return self
[docs]
def offset(self, n: int) -> Self:
"""Add OFFSET clause."""
self._offset_value = n
return self
[docs]
def group_by(self, *columns: Column[Any] | str) -> Self:
"""Add GROUP BY clause."""
self._group_by.extend(columns)
return self
[docs]
def having(self, cond: Expression) -> Self:
"""Add HAVING clause. Multiple calls combine with AND."""
if self._having_clause is not None:
self._having_clause = self._having_clause & cond
else:
self._having_clause = cond
return self
[docs]
def cache(
self,
ttl: float,
*,
lock_ttl: float | None = None,
retry_delay: float | None = None,
) -> Self:
"""Cache this query's results for ``ttl`` seconds."""
self._cache_ttl = ttl
self._cache_lock_ttl = lock_ttl
self._cache_retry_delay = retry_delay
return self
[docs]
def use_primary(self) -> Self:
"""Force this query to run against the primary database."""
self._force_primary = True
return self
[docs]
def distinct(self) -> Self:
"""Add DISTINCT to SELECT."""
self._distinct = True
return self
[docs]
def distinct_on(self, *columns: Column[Any]) -> Self:
"""Add DISTINCT ON to SELECT (PostgreSQL-specific)."""
self._distinct_on.extend(columns)
return self
[docs]
def for_update(self, *, nowait: bool = False, skip_locked: bool = False) -> Self:
"""Add FOR UPDATE row lock."""
self._lock = LockClause(LockMode.UPDATE, nowait=nowait, skip_locked=skip_locked)
return self
[docs]
def for_share(self, *, nowait: bool = False, skip_locked: bool = False) -> Self:
"""Add FOR SHARE row lock."""
self._lock = LockClause(LockMode.SHARE, nowait=nowait, skip_locked=skip_locked)
return self
[docs]
def as_(self, alias: str) -> SubqueryExpr:
"""Wrap this query as a subquery expression with an alias."""
return SubqueryExpr(self, _alias=alias)
[docs]
def exists(self) -> ExistsExpr:
"""Wrap this query as an EXISTS expression."""
return ExistsExpr(SubqueryExpr(self))
[docs]
def with_cte(self, name: str, query: SelectQuery[Any]) -> Self:
"""Add a Common Table Expression (WITH clause)."""
self._ctes.append((name, query))
return self
[docs]
def union(self, other: SelectQuery[Any]) -> SetOperationQuery[T]:
"""Combine with another query using UNION."""
return SetOperationQuery(self, "UNION", other)
[docs]
def union_all(self, other: SelectQuery[Any]) -> SetOperationQuery[T]:
"""Combine with another query using UNION ALL."""
return SetOperationQuery(self, "UNION ALL", other)
[docs]
def intersect(self, other: SelectQuery[Any]) -> SetOperationQuery[T]:
"""Combine with another query using INTERSECT."""
return SetOperationQuery(self, "INTERSECT", other)
[docs]
def except_(self, other: SelectQuery[Any]) -> SetOperationQuery[T]:
"""Combine with another query using EXCEPT."""
return SetOperationQuery(self, "EXCEPT", other)
[docs]
def build(self) -> tuple[str, list[Any]]:
"""Build the SQL query and parameters."""
params: list[Any] = []
# CTE (WITH) clause
cte_prefix = ""
if self._ctes:
cte_parts = []
for cte_name, cte_query in self._ctes:
cte_sql, cte_params = cte_query.build()
offset = len(params)
params.extend(cte_params)
renumbered = _renumber_params(cte_sql, offset)
cte_parts.append(f"{cte_name} AS ({renumbered})")
cte_prefix = f"WITH {', '.join(cte_parts)} "
# SELECT clause
select_parts: list[str] = []
for col in self._columns:
if isinstance(col, type) and issubclass(col, Table):
table_name = col.get_table_name()
select_parts.append(f"{table_name}.*")
elif isinstance(col, Expression):
select_parts.append(col.to_sql(params))
elif isinstance(col, Column):
if col._table_name and col._field_name:
select_parts.append(f"{col._table_name}.{col._field_name}")
elif col._field_name:
select_parts.append(col._field_name)
else:
select_parts.append(str(col))
# DISTINCT / DISTINCT ON
distinct_prefix = ""
if self._distinct_on:
on_parts = []
for dc in self._distinct_on:
if dc._table_name and dc._field_name:
on_parts.append(f"{dc._table_name}.{dc._field_name}")
elif dc._field_name:
on_parts.append(dc._field_name)
distinct_prefix = f"DISTINCT ON ({', '.join(on_parts)}) "
elif self._distinct:
distinct_prefix = "DISTINCT "
sql = f"{cte_prefix}SELECT {distinct_prefix}{', '.join(select_parts)}"
# FROM clause
if self._from_table is not None:
if isinstance(self._from_table, SubqueryExpr):
sql += f" FROM {self._from_table.to_sql(params)}"
elif isinstance(self._from_table, str):
sql += f" FROM {self._from_table}"
else:
sql += f" FROM {self._from_table.get_table_name()}"
# JOIN clauses
for join in self._joins:
join_table = join.table.get_table_name()
if join.join_type == JoinType.CROSS or join.condition is None:
sql += f" {join.join_type} JOIN {join_table}"
else:
condition_sql = join.condition.to_sql(params)
sql += f" {join.join_type} JOIN {join_table} ON {condition_sql}"
# WHERE clause
if self._where_clause:
where_sql = self._where_clause.to_sql(params)
sql += f" WHERE {where_sql}"
# GROUP BY clause
if self._group_by:
group_parts = []
for col in self._group_by:
if isinstance(col, Column) and col._table_name and col._field_name:
group_parts.append(f"{col._table_name}.{col._field_name}")
elif isinstance(col, Column) and col._field_name:
group_parts.append(col._field_name)
else:
group_parts.append(str(col))
sql += f" GROUP BY {', '.join(group_parts)}"
# HAVING clause
if self._having_clause is not None:
having_sql = self._having_clause.to_sql(params)
sql += f" HAVING {having_sql}"
# ORDER BY clause
if self._order_by:
order_parts = []
for ob in self._order_by:
if (
isinstance(ob.column, Column)
and ob.column._table_name
and ob.column._field_name
):
order_parts.append(
f"{ob.column._table_name}.{ob.column._field_name} "
f"{ob.direction}"
)
elif isinstance(ob.column, Column) and ob.column._field_name:
order_parts.append(f"{ob.column._field_name} {ob.direction}")
else:
order_parts.append(f"{ob.column} {ob.direction}")
sql += f" ORDER BY {', '.join(order_parts)}"
# LIMIT/OFFSET
if self._limit_value is not None:
sql += f" LIMIT {self._limit_value}"
if self._offset_value is not None:
sql += f" OFFSET {self._offset_value}"
# Row locking
if self._lock is not None:
sql += f" {self._lock.mode}"
if self._lock.nowait:
sql += " NOWAIT"
elif self._lock.skip_locked:
sql += " SKIP LOCKED"
return sql, params
[docs]
def build_count(self) -> tuple[str, list[Any]]:
"""Build a COUNT(*) SQL query and parameters."""
params: list[Any] = []
sql = "SELECT COUNT(*)"
if self._from_table is not None:
if isinstance(self._from_table, SubqueryExpr):
sql += f" FROM {self._from_table.to_sql(params)}"
elif isinstance(self._from_table, str):
sql += f" FROM {self._from_table}"
else:
sql += f" FROM {self._from_table.get_table_name()}"
for join in self._joins:
join_table = join.table.get_table_name()
if join.join_type == JoinType.CROSS or join.condition is None:
sql += f" {join.join_type} JOIN {join_table}"
else:
condition_sql = join.condition.to_sql(params)
sql += f" {join.join_type} JOIN {join_table} ON {condition_sql}"
if self._where_clause:
where_sql = self._where_clause.to_sql(params)
sql += f" WHERE {where_sql}"
return sql, params
def _cache_key(self, sql: str, params: list[Any]) -> str:
"""Derive a cache key from SQL and parameters."""
raw = sql + json.dumps(params, default=str)
digest = hashlib.sha256(raw.encode()).hexdigest()
return f"derp:query:{digest}"
def _effective_pool(self) -> asyncpg.Pool | asyncpg.Connection:
"""Return the pool to use, considering the replica router."""
if self._pool is None:
raise RuntimeError("No database connection. Call db.connect() first.")
if (
self._router is not None
and not self._force_primary
and isinstance(self._pool, asyncpg.Pool)
):
return self._router.get_read_pool()
return self._pool
[docs]
async def execute(self) -> list[T]:
"""Execute the query and return results."""
pool = self._effective_pool()
sql, params = self.build()
if self._cache_store is not None and self._cache_ttl is not None:
cache_key = self._cache_key(sql, params).encode()
async def _compute() -> bytes:
async with _acquire(pool) as conn:
rows = await conn.fetch(sql, *params)
return json.dumps(self._rows_to_dicts(rows), default=str).encode()
guard_kwargs: dict[str, Any] = {}
if self._cache_lock_ttl is not None:
guard_kwargs["lock_ttl"] = self._cache_lock_ttl
if self._cache_retry_delay is not None:
guard_kwargs["retry_delay"] = self._cache_retry_delay
cached = await self._cache_store.guarded_get(
cache_key, compute=_compute, ttl=self._cache_ttl, **guard_kwargs
)
rows_data: list[dict[str, Any]] = json.loads(cached)
return self._hydrate(rows_data)
async with _acquire(pool) as conn:
rows = await conn.fetch(sql, *params)
# Fast path: single-Table select with no JSON columns — pass
# asyncpg Records straight to _from_row(), skipping the
# intermediate dict(row) conversion.
model_class = self._single_table_model()
if model_class is not None and not _json_columns(model_class):
return [ # type: ignore[return-value]
model_class._from_row(row) for row in rows
]
return self._hydrate(self._rows_to_dicts(rows))
def _single_table_model(self) -> type[Table] | None:
"""Return the Table class when selecting a single table."""
if (
len(self._columns) == 1
and isinstance(self._columns[0], type)
and issubclass(self._columns[0], Table)
):
return self._columns[0]
return None
def _rows_to_dicts(self, rows: list[asyncpg.Record]) -> list[dict[str, Any]]:
"""Convert asyncpg Records to plain dicts with JSON deserialization."""
table = None
if self._from_table is not None and not isinstance(
self._from_table, str | SubqueryExpr
):
table = self._from_table
elif (
len(self._columns) == 1
and isinstance(self._columns[0], type)
and issubclass(self._columns[0], Table)
):
table = self._columns[0]
if table is not None:
json_cols = _json_columns(table)
if json_cols:
return [_deserialize_row(table, dict(row)) for row in rows]
return [dict(row) for row in rows]
return [dict(row) for row in rows]
def _is_single_column(self) -> bool:
"""True when selecting exactly one Column descriptor."""
return len(self._columns) == 1 and isinstance(self._columns[0], Column)
def _is_multi_column(self) -> bool:
"""True when selecting multiple Column descriptors."""
return len(self._columns) > 1 and all(
isinstance(c, Column) for c in self._columns
)
def _hydrate(self, rows_data: list[dict[str, Any]]) -> list[T]:
"""Hydrate dicts into model instances or return as-is."""
model_class = self._single_table_model()
if model_class is not None:
return [ # type: ignore[return-value]
model_class._from_row(row) for row in rows_data
]
if self._is_single_column():
return [next(iter(row.values())) for row in rows_data]
if self._is_multi_column():
return [ # type: ignore[return-value]
tuple(row.values()) for row in rows_data
]
return rows_data # type: ignore[return-value]
[docs]
async def first_or_none(self) -> T | None:
"""Execute and return first result or None."""
self._limit_value = 1
results = await self.execute()
return results[0] if results else None
[docs]
async def first(self) -> T:
"""Execute and return first result."""
result = await self.first_or_none()
if result is None:
raise RuntimeError("SELECT query returned no results")
return result
[docs]
async def count(self) -> int:
"""Execute a COUNT(*) query and return the count."""
pool = self._effective_pool()
sql, params = self.build_count()
async with _acquire(pool) as conn:
row = await conn.fetchrow(sql, *params)
return row[0] if row else 0
# =============================================================================
# Set Operation Query (UNION, INTERSECT, EXCEPT)
# =============================================================================
class SetOperationQuery[T]:
"""Combined query from UNION, INTERSECT, or EXCEPT."""
def __init__(
self,
left: SelectQuery[T],
op: str,
right: SelectQuery[Any],
):
self._left = left
self._op = op
self._right = right
self._order_by: list[OrderByClause] = []
self._limit_value: int | None = None
self._offset_value: int | None = None
def order_by(self, column: Column[Any] | str, *, asc: bool = True) -> Self:
"""Add ORDER BY to the combined result."""
self._order_by.append(
OrderByClause(column, SortOrder.ASC if asc else SortOrder.DESC)
)
return self
def limit(self, n: int) -> Self:
"""Add LIMIT to the combined result."""
self._limit_value = n
return self
def offset(self, n: int) -> Self:
"""Add OFFSET to the combined result."""
self._offset_value = n
return self
def build(self) -> tuple[str, list[Any]]:
"""Build the combined SQL query."""
left_sql, left_params = self._left.build()
right_sql, right_params = self._right.build()
params = list(left_params)
offset = len(params)
params.extend(right_params)
renumbered_right = _renumber_params(right_sql, offset)
sql = f"{left_sql} {self._op} {renumbered_right}"
if self._order_by:
order_parts = []
for ob in self._order_by:
if (
isinstance(ob.column, Column)
and ob.column._table_name
and ob.column._field_name
):
order_parts.append(
f"{ob.column._table_name}.{ob.column._field_name} "
f"{ob.direction}"
)
elif isinstance(ob.column, Column) and ob.column._field_name:
order_parts.append(f"{ob.column._field_name} {ob.direction}")
else:
order_parts.append(f"{ob.column} {ob.direction}")
sql += f" ORDER BY {', '.join(order_parts)}"
if self._limit_value is not None:
sql += f" LIMIT {self._limit_value}"
if self._offset_value is not None:
sql += f" OFFSET {self._offset_value}"
return sql, params
# =============================================================================
# INSERT Query with typed returning()
# =============================================================================
class _InsertQueryBase[T: Table]:
"""Base class for INSERT queries with shared implementation."""
def __init__(
self,
pool: asyncpg.Pool | asyncpg.Connection | None,
table: type[T],
*,
router: ReplicaRouter | None = None,
):
self._pool = pool
self._table = table
self._values: dict[str, Any] = {}
self._values_list: list[dict[str, Any]] | None = None
self._on_conflict: OnConflictClause | None = None
self._insert_columns: list[str] | None = None
self._from_select: SelectQuery[Any] | None = None
self._returning: tuple[type[Table] | Column[Any], ...] | None = None
self._router: ReplicaRouter | None = router
def _chunk_values_list(
self,
) -> list[list[dict[str, Any]]] | None:
"""Split ``_values_list`` into chunks that fit within PG's
parameter limit. Returns ``None`` when no chunking is needed.
"""
if self._values_list is None or not self._values_list:
return None
num_cols = len(self._values_list[0])
if num_cols == 0:
return None
total_params = num_cols * len(self._values_list)
if total_params <= _PG_MAX_PARAMS:
return None
rows_per_chunk = _PG_MAX_PARAMS // num_cols
return [
self._values_list[i : i + rows_per_chunk]
for i in range(0, len(self._values_list), rows_per_chunk)
]
def _build(self) -> tuple[str, list[Any]]:
"""Build the SQL query and parameters."""
table_name = self._table.get_table_name()
params: list[Any] = []
# INSERT ... SELECT
if self._from_select is not None:
cols = ", ".join(self._insert_columns or [])
sub_sql, sub_params = self._from_select.build()
params.extend(sub_params)
sql = f"INSERT INTO {table_name} ({cols}) {sub_sql}"
if self._returning:
return_parts = []
for col in self._returning:
if isinstance(col, type) and issubclass(col, Table):
return_parts.append("*")
elif isinstance(col, Column) and col._field_name:
return_parts.append(col._field_name)
sql += f" RETURNING {', '.join(return_parts)}"
return sql, params
if self._values_list is not None:
# Multi-row insert
if not self._values_list:
raise ValueError("values_list() requires at least one row.")
columns = list(self._values_list[0].keys())
all_placeholders = []
for row in self._values_list:
row_ph = []
for col in columns:
params.append(_serialize_value(self._table, col, row[col]))
row_ph.append(f"${len(params)}")
all_placeholders.append(f"({', '.join(row_ph)})")
sql = (
f"INSERT INTO {table_name} ({', '.join(columns)}) "
f"VALUES {', '.join(all_placeholders)}"
)
else:
# Single-row insert
columns = list(self._values.keys())
json_cols = _json_columns(self._table)
if json_cols:
for col, val in self._values.items():
params.append(_serialize_value(self._table, col, val))
else:
params.extend(self._values.values())
placeholders = [f"${i + 1}" for i in range(len(params))]
sql = (
f"INSERT INTO {table_name} ({', '.join(columns)}) "
f"VALUES ({', '.join(placeholders)})"
)
# ON CONFLICT
if self._on_conflict is not None:
target_cols = ", ".join(self._on_conflict.target)
sql += f" ON CONFLICT ({target_cols})"
if self._on_conflict.action == "nothing":
sql += " DO NOTHING"
else:
set_parts = []
for col, val in (self._on_conflict.set_values or {}).items():
params.append(_serialize_value(self._table, col, val))
set_parts.append(f"{col} = ${len(params)}")
sql += f" DO UPDATE SET {', '.join(set_parts)}"
if self._returning:
return_parts = []
for col in self._returning:
if isinstance(col, type) and issubclass(col, Table):
return_parts.append("*")
elif isinstance(col, Column) and col._field_name:
return_parts.append(col._field_name)
sql += f" RETURNING {', '.join(return_parts)}"
return sql, params
def _resolve_target(
self,
target: Column[Any] | tuple[Column[Any], ...],
) -> tuple[str, ...]:
"""Resolve conflict target to column name strings."""
if isinstance(target, Column):
return (target._field_name or "",)
return tuple(f._field_name or "" for f in target)
[docs]
class InsertQuery[T: Table](_InsertQueryBase[T]):
"""INSERT query without RETURNING - execute() returns None."""
[docs]
def values(self, **kwargs: Any) -> InsertQuery[T]:
"""Set values to insert."""
self._values = kwargs
return self
[docs]
def values_list(self, rows: list[dict[str, Any]]) -> InsertBulkQuery[T]:
"""Set multiple rows to insert. Returns a bulk query."""
query: InsertBulkQuery[T] = InsertBulkQuery(
self._pool, self._table, router=self._router
)
query._values_list = rows
query._on_conflict = self._on_conflict
query._insert_columns = self._insert_columns
query._from_select = self._from_select
return query
[docs]
def columns(self, *cols: Column[Any] | str) -> InsertQuery[T]:
"""Set column names for INSERT ... SELECT."""
resolved: list[str] = []
for c in cols:
if isinstance(c, Column) and c._field_name:
resolved.append(c._field_name)
else:
resolved.append(str(c))
self._insert_columns = resolved
return self
[docs]
def from_select(self, query: SelectQuery[Any]) -> InsertQuery[T]:
"""Set the SELECT query for INSERT ... SELECT."""
self._from_select = query
return self
[docs]
def ignore_conflicts(
self,
*,
target: Column[Any] | tuple[Column[Any], ...],
) -> InsertQueryIgnoreConflicts[T]:
"""Add ON CONFLICT DO NOTHING.
Returns a query whose ``returning().execute()`` yields
``T | None`` instead of ``T``, since the conflict may suppress
the insert.
"""
self._on_conflict = OnConflictClause(
target=self._resolve_target(target),
action="nothing",
)
query: InsertQueryIgnoreConflicts[T] = InsertQueryIgnoreConflicts(
self._pool, self._table, router=self._router
)
query._values = self._values
query._values_list = self._values_list
query._on_conflict = self._on_conflict
query._insert_columns = self._insert_columns
query._from_select = self._from_select
return query
[docs]
def upsert(
self,
*,
target: Column[Any] | tuple[Column[Any], ...],
**kwargs: Any,
) -> InsertQuery[T]:
"""Add ON CONFLICT DO UPDATE SET (upsert).
Pass the columns to update as keyword arguments::
.upsert(target=User.email, name="Updated")
"""
self._on_conflict = OnConflictClause(
target=self._resolve_target(target),
action="update",
set_values=kwargs,
)
return self
# fmt: off
@overload
def returning(self, table: type[T], /) -> ReturningOne[T]: ...
@overload
def returning[V](self, c1: Column[V], /) -> ReturningOneScalar[T, V]: ...
@overload
def returning[A, B](self, c1: Column[A], c2: Column[B], /) -> ROT2[T, A, B]: ...
@overload
def returning[A, B, C](self, c1: Column[A], c2: Column[B], c3: Column[C], /) -> ROT3[T, A, B, C]: ...
@overload
def returning[A, B, C, D](self, c1: Column[A], c2: Column[B], c3: Column[C], c4: Column[D], /) -> ROT4[T, A, B, C, D]: ...
@overload
def returning[A, B, C, D, E](self, c1: Column[A], c2: Column[B], c3: Column[C], c4: Column[D], c5: Column[E], /) -> ROT5[T, A, B, C, D, E]: ...
@overload
def returning[A, B, C, D, E, F](self, c1: Column[A], c2: Column[B], c3: Column[C], c4: Column[D], c5: Column[E], c6: Column[F], /) -> ROT6[T, A, B, C, D, E, F]: ...
@overload
def returning[A, B, C, D, E, F, G](self, c1: Column[A], c2: Column[B], c3: Column[C], c4: Column[D], c5: Column[E], c6: Column[F], c7: Column[G], /) -> ROT7[T, A, B, C, D, E, F, G]: ...
@overload
def returning[A, B, C, D, E, F, G, H](self, c1: Column[A], c2: Column[B], c3: Column[C], c4: Column[D], c5: Column[E], c6: Column[F], c7: Column[G], c8: Column[H], /) -> ROT8[T, A, B, C, D, E, F, G, H]: ...
@overload
def returning[A, B, C, D, E, F, G, H, I](self, c1: Column[A], c2: Column[B], c3: Column[C], c4: Column[D], c5: Column[E], c6: Column[F], c7: Column[G], c8: Column[H], c9: Column[I], /) -> ROT9[T, A, B, C, D, E, F, G, H, I]: ...
@overload
def returning[A, B, C, D, E, F, G, H, I, J](self, c1: Column[A], c2: Column[B], c3: Column[C], c4: Column[D], c5: Column[E], c6: Column[F], c7: Column[G], c8: Column[H], c9: Column[I], c10: Column[J], /) -> ROT10[T, A, B, C, D, E, F, G, H, I, J]: ...
# fmt: on
[docs]
def returning(
self, *columns: Any
) -> ReturningOne[T] | ReturningOneScalar[T, Any] | ReturningOneTuple[T]:
"""Add RETURNING clause."""
if (
len(columns) == 1
and isinstance(columns[0], type)
and issubclass(columns[0], Table)
):
return ReturningOne(self, columns)
if len(columns) == 1:
return ReturningOneScalar(self, columns)
return ReturningOneTuple(self, columns)
[docs]
def build(self) -> tuple[str, list[Any]]:
"""Build the SQL query and parameters."""
return self._build()
[docs]
async def execute(self) -> None:
"""Execute the insert."""
if not self._pool:
raise RuntimeError("No database connection. Call db.connect() first.")
sql, params = self.build()
async with _acquire(self._pool) as conn:
await conn.execute(sql, *params)
if self._router is not None:
self._router.record_write()
# =============================================================================
# Shared RETURNING executors
# =============================================================================
class _ReturningBase[T: Table]:
"""Shared state for all RETURNING query executors."""
__slots__ = ("_parent", "_columns")
def __init__(
self,
parent: Any,
columns: tuple[type[Table] | Column[Any], ...],
) -> None:
self._parent = parent
# Set _returning on the parent so _build() generates the clause.
parent._returning = columns
self._columns = columns
def build(self) -> tuple[str, list[Any]]:
return self._parent._build()
def _is_table_return(self) -> bool:
return (
len(self._columns) == 1
and isinstance(self._columns[0], type)
and issubclass(self._columns[0], Table)
)
def _row_to_model(self, row: Any) -> Any:
return self._parent._table._from_row(_deserialize_row(self._parent._table, row))
def _row_to_scalar(self, row: Any) -> Any:
deserialized = _deserialize_row(self._parent._table, row)
col = self._columns[0]
if isinstance(col, Column) and col._field_name:
return deserialized[col._field_name]
raise RuntimeError("Expected a single Column for scalar return")
def _row_to_tuple(self, row: Any) -> tuple[Any, ...]:
deserialized = _deserialize_row(self._parent._table, row)
return tuple(
deserialized[col._field_name]
for col in self._columns
if isinstance(col, Column) and col._field_name
)
def _record_write(self) -> None:
router = self._parent._router
if router is not None:
router.record_write()
def _check_pool(self) -> Any:
pool = self._parent._pool
if not pool:
raise RuntimeError("No database connection. Call db.connect() first.")
return pool
class ReturningOne[T: Table](_ReturningBase[T]):
"""Single-row RETURNING (INSERT) → ``T``."""
async def execute(self) -> T:
pool = self._check_pool()
sql, params = self.build()
async with _acquire(pool) as conn:
row = await conn.fetchrow(sql, *params)
if row is None:
raise RuntimeError("INSERT RETURNING returned no rows")
self._record_write()
return self._row_to_model(row)
class ReturningOneScalar[T: Table, V](_ReturningBase[T]):
"""Single-row RETURNING one column (INSERT) → scalar ``V``."""
async def execute(self) -> V:
pool = self._check_pool()
sql, params = self.build()
async with _acquire(pool) as conn:
row = await conn.fetchrow(sql, *params)
if row is None:
raise RuntimeError("INSERT RETURNING returned no rows")
self._record_write()
return self._row_to_scalar(row)
class ReturningOneTuple[T: Table](_ReturningBase[T]):
"""Single-row RETURNING 2+ columns (INSERT) → ``tuple[Any, ...]``."""
async def execute(self) -> tuple[Any, ...]:
pool = self._check_pool()
sql, params = self.build()
async with _acquire(pool) as conn:
row = await conn.fetchrow(sql, *params)
if row is None:
raise RuntimeError("INSERT RETURNING returned no rows")
self._record_write()
return self._row_to_tuple(row)
class ReturningOneOptional[T: Table](_ReturningBase[T]):
"""Single-row RETURNING with ON CONFLICT → ``T | None``."""
async def execute(self) -> T | None:
pool = self._check_pool()
sql, params = self.build()
async with _acquire(pool) as conn:
row = await conn.fetchrow(sql, *params)
if row is None:
return None
self._record_write()
return self._row_to_model(row)
class ReturningOneScalarOptional[T: Table, V](_ReturningBase[T]):
"""Single-row RETURNING one column with ON CONFLICT → ``V | None``."""
async def execute(self) -> V | None:
pool = self._check_pool()
sql, params = self.build()
async with _acquire(pool) as conn:
row = await conn.fetchrow(sql, *params)
if row is None:
return None
self._record_write()
return self._row_to_scalar(row)
class ReturningOneTupleOptional[T: Table](_ReturningBase[T]):
"""Single-row RETURNING 2+ columns with ON CONFLICT → ``tuple[Any, ...] | None``."""
async def execute(self) -> tuple[Any, ...] | None:
pool = self._check_pool()
sql, params = self.build()
async with _acquire(pool) as conn:
row = await conn.fetchrow(sql, *params)
if row is None:
return None
self._record_write()
return self._row_to_tuple(row)
class ReturningMany[T: Table](_ReturningBase[T]):
"""Multi-row RETURNING (UPDATE/DELETE) → ``list[T]``."""
async def execute(self) -> list[T]:
pool = self._check_pool()
sql, params = self.build()
async with _acquire(pool) as conn:
rows = await conn.fetch(sql, *params)
self._record_write()
return [self._row_to_model(row) for row in rows]
class ReturningManyScalar[T: Table, V](_ReturningBase[T]):
"""Multi-row RETURNING one column (UPDATE/DELETE) → ``list[V]``."""
async def execute(self) -> list[V]:
pool = self._check_pool()
sql, params = self.build()
async with _acquire(pool) as conn:
rows = await conn.fetch(sql, *params)
self._record_write()
return [self._row_to_scalar(row) for row in rows]
class ReturningManyTuple[T: Table](_ReturningBase[T]):
"""Multi-row RETURNING 2+ columns (UPDATE/DELETE) → ``list[tuple[Any, ...]]``."""
async def execute(self) -> list[tuple[Any, ...]]:
pool = self._check_pool()
sql, params = self.build()
async with _acquire(pool) as conn:
rows = await conn.fetch(sql, *params)
self._record_write()
return [self._row_to_tuple(row) for row in rows]
# =============================================================================
# Shared returning mixin for multi-row queries
# =============================================================================
class _ReturningManyMixin[T: Table]:
"""``returning()`` overloads for multi-row queries.
Used by ``InsertBulkQuery``, ``UpdateQuery``, and ``DeleteQuery``.
"""
# fmt: off
@overload
def returning(self, table: type[T], /) -> ReturningMany[T]: ...
@overload
def returning[V](self, c1: Column[V], /) -> ReturningManyScalar[T, V]: ...
@overload
def returning[A, B](self, c1: Column[A], c2: Column[B], /) -> RMT2[T, A, B]: ...
@overload
def returning[A, B, C](self, c1: Column[A], c2: Column[B], c3: Column[C], /) -> RMT3[T, A, B, C]: ...
@overload
def returning[A, B, C, D](self, c1: Column[A], c2: Column[B], c3: Column[C], c4: Column[D], /) -> RMT4[T, A, B, C, D]: ...
@overload
def returning[A, B, C, D, E](self, c1: Column[A], c2: Column[B], c3: Column[C], c4: Column[D], c5: Column[E], /) -> RMT5[T, A, B, C, D, E]: ...
@overload
def returning[A, B, C, D, E, F](self, c1: Column[A], c2: Column[B], c3: Column[C], c4: Column[D], c5: Column[E], c6: Column[F], /) -> RMT6[T, A, B, C, D, E, F]: ...
@overload
def returning[A, B, C, D, E, F, G](self, c1: Column[A], c2: Column[B], c3: Column[C], c4: Column[D], c5: Column[E], c6: Column[F], c7: Column[G], /) -> RMT7[T, A, B, C, D, E, F, G]: ...
@overload
def returning[A, B, C, D, E, F, G, H](self, c1: Column[A], c2: Column[B], c3: Column[C], c4: Column[D], c5: Column[E], c6: Column[F], c7: Column[G], c8: Column[H], /) -> RMT8[T, A, B, C, D, E, F, G, H]: ...
@overload
def returning[A, B, C, D, E, F, G, H, I](self, c1: Column[A], c2: Column[B], c3: Column[C], c4: Column[D], c5: Column[E], c6: Column[F], c7: Column[G], c8: Column[H], c9: Column[I], /) -> RMT9[T, A, B, C, D, E, F, G, H, I]: ...
@overload
def returning[A, B, C, D, E, F, G, H, I, J](self, c1: Column[A], c2: Column[B], c3: Column[C], c4: Column[D], c5: Column[E], c6: Column[F], c7: Column[G], c8: Column[H], c9: Column[I], c10: Column[J], /) -> RMT10[T, A, B, C, D, E, F, G, H, I, J]: ...
# fmt: on
def returning(
self, *columns: Any
) -> ReturningMany[T] | ReturningManyScalar[T, Any] | ReturningManyTuple[T]:
"""Add RETURNING clause."""
if (
len(columns) == 1
and isinstance(columns[0], type)
and issubclass(columns[0], Table)
):
return ReturningMany(self, columns)
if len(columns) == 1:
return ReturningManyScalar(self, columns)
return ReturningManyTuple(self, columns)
# =============================================================================
# INSERT returning() + ignore_conflicts
# =============================================================================
class InsertQueryIgnoreConflicts[T: Table](_InsertQueryBase[T]):
"""INSERT … ON CONFLICT DO NOTHING query."""
# fmt: off
@overload
def returning(self, table: type[T], /) -> ReturningOneOptional[T]: ...
@overload
def returning[V](self, c1: Column[V], /) -> ReturningOneScalarOptional[T, V]: ...
@overload
def returning[A, B](self, c1: Column[A], c2: Column[B], /) -> ROTO2[T, A, B]: ...
@overload
def returning[A, B, C](self, c1: Column[A], c2: Column[B], c3: Column[C], /) -> ROTO3[T, A, B, C]: ...
@overload
def returning[A, B, C, D](self, c1: Column[A], c2: Column[B], c3: Column[C], c4: Column[D], /) -> ROTO4[T, A, B, C, D]: ...
@overload
def returning[A, B, C, D, E](self, c1: Column[A], c2: Column[B], c3: Column[C], c4: Column[D], c5: Column[E], /) -> ROTO5[T, A, B, C, D, E]: ...
@overload
def returning[A, B, C, D, E, F](self, c1: Column[A], c2: Column[B], c3: Column[C], c4: Column[D], c5: Column[E], c6: Column[F], /) -> ROTO6[T, A, B, C, D, E, F]: ...
@overload
def returning[A, B, C, D, E, F, G](self, c1: Column[A], c2: Column[B], c3: Column[C], c4: Column[D], c5: Column[E], c6: Column[F], c7: Column[G], /) -> ROTO7[T, A, B, C, D, E, F, G]: ...
@overload
def returning[A, B, C, D, E, F, G, H](self, c1: Column[A], c2: Column[B], c3: Column[C], c4: Column[D], c5: Column[E], c6: Column[F], c7: Column[G], c8: Column[H], /) -> ROTO8[T, A, B, C, D, E, F, G, H]: ...
@overload
def returning[A, B, C, D, E, F, G, H, I](self, c1: Column[A], c2: Column[B], c3: Column[C], c4: Column[D], c5: Column[E], c6: Column[F], c7: Column[G], c8: Column[H], c9: Column[I], /) -> ROTO9[T, A, B, C, D, E, F, G, H, I]: ...
@overload
def returning[A, B, C, D, E, F, G, H, I, J](self, c1: Column[A], c2: Column[B], c3: Column[C], c4: Column[D], c5: Column[E], c6: Column[F], c7: Column[G], c8: Column[H], c9: Column[I], c10: Column[J], /) -> ROTO10[T, A, B, C, D, E, F, G, H, I, J]: ...
# fmt: on
def returning(
self, *columns: Any
) -> (
ReturningOneOptional[T]
| ReturningOneScalarOptional[T, Any]
| ReturningOneTupleOptional[T]
):
"""Add RETURNING clause (optional result due to ON CONFLICT)."""
if (
len(columns) == 1
and isinstance(columns[0], type)
and issubclass(columns[0], Table)
):
return ReturningOneOptional(self, columns)
if len(columns) == 1:
return ReturningOneScalarOptional(self, columns)
return ReturningOneTupleOptional(self, columns)
def build(self) -> tuple[str, list[Any]]:
return self._build()
async def execute(self) -> None:
"""Execute the insert (no RETURNING)."""
if not self._pool:
raise RuntimeError("No database connection. Call db.connect() first.")
sql, params = self.build()
async with _acquire(self._pool) as conn:
await conn.execute(sql, *params)
if self._router is not None:
self._router.record_write()
class InsertBulkQuery[T: Table](_InsertQueryBase[T], _ReturningManyMixin[T]):
"""INSERT with multiple rows via ``values_list()``.
``returning()`` uses multi-row fetch and returns ``list[T]``,
``list[V]``, or ``list[tuple]``.
"""
def ignore_conflicts(
self,
*,
target: Column[Any] | tuple[Column[Any], ...],
) -> InsertBulkQuery[T]:
"""Add ON CONFLICT DO NOTHING."""
self._on_conflict = OnConflictClause(
target=self._resolve_target(target),
action="nothing",
)
return self
def upsert(
self,
*,
target: Column[Any] | tuple[Column[Any], ...],
**kwargs: Any,
) -> InsertBulkQuery[T]:
"""Add ON CONFLICT DO UPDATE SET (upsert)."""
self._on_conflict = OnConflictClause(
target=self._resolve_target(target),
action="update",
set_values=kwargs,
)
return self
def build(self) -> tuple[str, list[Any]]:
return self._build()
async def execute(self) -> None:
"""Execute the bulk insert.
Rows are automatically split into chunks when the total
parameter count would exceed PostgreSQL's 65 535 limit.
Chunks run inside a single transaction so the operation
is atomic.
"""
if not self._pool:
raise RuntimeError("No database connection. Call db.connect() first.")
chunks = self._chunk_values_list()
if chunks is not None:
saved = self._values_list
async with _acquire(self._pool) as conn:
async with conn.transaction():
for chunk in chunks:
self._values_list = chunk
sql, params = self.build()
await conn.execute(sql, *params)
self._values_list = saved
else:
sql, params = self.build()
async with _acquire(self._pool) as conn:
await conn.execute(sql, *params)
if self._router is not None:
self._router.record_write()
# =============================================================================
# UPDATE Query with typed returning()
# =============================================================================
[docs]
class UpdateQuery[T: Table](_WhereShorthandMixin, _ReturningManyMixin[T]):
"""UPDATE query — execute() returns None, returning() for results."""
[docs]
def __init__(
self,
pool: asyncpg.Pool | asyncpg.Connection | None,
table: type[T],
*,
router: ReplicaRouter | None = None,
):
self._pool = pool
self._table = table
self._set_values: dict[str, Any] = {}
self._where_clause: Expression | None = None
self._returning: tuple[type[Table] | Column[Any], ...] | None = None
self._router: ReplicaRouter | None = router
[docs]
def set(self, **kwargs: Any) -> UpdateQuery[T]:
"""Set values to update."""
self._set_values = kwargs
return self
[docs]
def where(self, cond: Expression) -> UpdateQuery[T]:
"""Add WHERE clause. Multiple calls combine with AND."""
if self._where_clause is not None:
self._where_clause = self._where_clause & cond
else:
self._where_clause = cond
return self
def _build(self) -> tuple[str, list[Any]]:
table_name = self._table.get_table_name()
params: list[Any] = []
set_parts = []
json_cols = _json_columns(self._table)
for col, val in self._set_values.items():
if isinstance(val, RawSQL):
set_parts.append(f"{col} = {val.to_sql(params)}")
else:
if json_cols and col in json_cols:
val = _serialize_value(self._table, col, val)
params.append(val)
set_parts.append(f"{col} = ${len(params)}")
sql = f"UPDATE {table_name} SET {', '.join(set_parts)}"
if self._where_clause:
where_sql = self._where_clause.to_sql(params)
sql += f" WHERE {where_sql}"
if self._returning:
return_parts = []
for col in self._returning:
if isinstance(col, type) and issubclass(col, Table):
return_parts.append("*")
elif isinstance(col, Column) and col._field_name:
return_parts.append(col._field_name)
sql += f" RETURNING {', '.join(return_parts)}"
return sql, params
[docs]
def build(self) -> tuple[str, list[Any]]:
return self._build()
[docs]
async def execute(self) -> None:
"""Execute the update."""
if not self._pool:
raise RuntimeError("No database connection. Call db.connect() first.")
sql, params = self.build()
async with _acquire(self._pool) as conn:
await conn.execute(sql, *params)
if self._router is not None:
self._router.record_write()
# =============================================================================
# DELETE Query
# =============================================================================
[docs]
class DeleteQuery[T: Table](_WhereShorthandMixin, _ReturningManyMixin[T]):
"""DELETE query — execute() returns None, returning() for results."""
[docs]
def __init__(
self,
pool: asyncpg.Pool | asyncpg.Connection | None,
table: type[T],
*,
router: ReplicaRouter | None = None,
):
self._pool = pool
self._table = table
self._where_clause: Expression | None = None
self._returning: tuple[type[Table] | Column[Any], ...] | None = None
self._router: ReplicaRouter | None = router
[docs]
def where(self, cond: Expression) -> DeleteQuery[T]:
"""Add WHERE clause. Multiple calls combine with AND."""
if self._where_clause is not None:
self._where_clause = self._where_clause & cond
else:
self._where_clause = cond
return self
def _build(self) -> tuple[str, list[Any]]:
table_name = self._table.get_table_name()
params: list[Any] = []
sql = f"DELETE FROM {table_name}"
if self._where_clause:
where_sql = self._where_clause.to_sql(params)
sql += f" WHERE {where_sql}"
if self._returning:
return_parts = []
for col in self._returning:
if isinstance(col, type) and issubclass(col, Table):
return_parts.append("*")
elif isinstance(col, Column) and col._field_name:
return_parts.append(col._field_name)
sql += f" RETURNING {', '.join(return_parts)}"
return sql, params
[docs]
def build(self) -> tuple[str, list[Any]]:
return self._build()
[docs]
async def execute(self) -> None:
"""Execute the delete."""
if not self._pool:
raise RuntimeError("No database connection. Call db.connect() first.")
sql, params = self.build()
async with _acquire(self._pool) as conn:
await conn.execute(sql, *params)
if self._router is not None:
self._router.record_write()
# =============================================================================
# Helpers
# =============================================================================
def _serialize_value(table: type[Table], column: str, value: Any) -> Any:
    """Prepare *value* for a database write (handles JSON/JSONB columns).

    A ``dict`` or ``list`` bound for a JSON/JSONB column of *table* is
    encoded with ``json.dumps``; every other value is returned unchanged.
    """
    targets_json = column in _json_columns(table)
    if targets_json and isinstance(value, (dict, list)):
        return json.dumps(value)
    return value
# Per-table memo of JSON/JSONB column names, filled on first lookup.
_JSON_COLUMNS_CACHE: dict[type, frozenset[str]] = {}


def _json_columns(table: type[Table]) -> frozenset[str]:
    """Return the names of *table*'s JSON/JSONB columns (memoized per table)."""
    names = _JSON_COLUMNS_CACHE.get(table)
    if names is None:
        names = frozenset(
            column_name
            for column_name, column in table.get_columns().items()
            if column.sql_type() in ("JSON", "JSONB")
        )
        _JSON_COLUMNS_CACHE[table] = names
    return names
def _deserialize_row(table: type[Table], row: dict[str, Any]) -> dict[str, Any]:
    """Decode JSON/JSONB text values in a row read from the database.

    When *table* has no JSON columns, *row* itself is returned untouched.
    Otherwise a ``dict`` copy is returned in which every JSON column that
    holds a ``str`` has been parsed with ``json.loads``.
    """
    json_names = _json_columns(table)
    if not json_names:
        # Nothing to decode: skip the copy altogether.
        return row
    decoded = dict(row)
    for name in json_names:
        raw = decoded.get(name)
        if not isinstance(raw, str):
            continue
        decoded[name] = json.loads(raw)
    return decoded