"""Celery queue client."""
from __future__ import annotations
import asyncio
from collections.abc import Sequence
from datetime import timedelta
from typing import Any
from etils import epy
from derp.config import CeleryConfig
from derp.queue.base import QueueClient, Schedule, ScheduleType, TaskState, TaskStatus
from derp.queue.exceptions import QueueNotConnectedError, QueueProviderError
with epy.lazy_imports():
import celery
import celery.result as celery_result
import celery.schedules as celery_schedules
# Translation table from Celery's task state names to TaskState.
# Lookups for state names missing from this table are handled by the caller.
_CELERY_STATE_MAP: dict[str, TaskState] = {
    "PENDING": TaskState.PENDING,
    "STARTED": TaskState.STARTED,
    "SUCCESS": TaskState.SUCCESS,
    "FAILURE": TaskState.FAILURE,
    "REVOKED": TaskState.REVOKED,
    # The two Celery states below are folded into the closest TaskState
    # member rather than being given members of their own.
    "RETRY": TaskState.STARTED,
    "RECEIVED": TaskState.PENDING,
}
[docs]
class CeleryQueueClient(QueueClient):
    """Queue client backed by Celery.

    The underlying Celery app is created lazily, either by :meth:`connect`
    or on first access of :attr:`app`. Schedules registered before the app
    exists are remembered and written into the app's ``beat_schedule`` as
    soon as the app is created, whichever path creates it.
    """

    # Capability flags for this provider.
    # NOTE(review): presumably read by QueueClient / its callers -- confirm
    # against the base class.
    supports_result = True
    supports_revoke = True
    supports_delay = True

    def __init__(self, config: CeleryConfig):
        """Store the configuration; no Celery app is created yet.

        Args:
            config: Celery settings (broker URL, result backend, serializers
                and default queue) used when the app is created.
        """
        self._config = config
        self._app: celery.Celery | None = None
        self._schedules: list[Schedule] = []

    @property
    def app(self) -> celery.Celery:
        """Expose the underlying Celery app for worker-side task registration.

        The app is created on first access if :meth:`connect` has not been
        called yet.
        """
        if self._app is None:
            self._create_app()
        return self._app  # type: ignore[return-value]

    def _create_app(self) -> None:
        """Build and configure the Celery app, then apply pending schedules."""
        self._app = celery.Celery("derp")
        self._app.conf.update(
            broker_url=self._config.broker_url,
            result_backend=self._config.result_backend,
            task_serializer=self._config.task_serializer,
            result_serializer=self._config.result_serializer,
            task_default_queue=self._config.task_default_queue,
        )
        # Apply schedules here rather than only in connect(): the app can also
        # be created through the `app` property, after which connect() returns
        # early and would otherwise never install the registered schedules.
        if self._schedules:
            self._apply_beat_schedule()

    async def connect(self) -> None:
        """Create the Celery app if it does not exist yet (idempotent)."""
        if self._app is not None:
            return
        self._create_app()

    async def disconnect(self) -> None:
        """Close the Celery app, if any, and forget it."""
        # Drop our reference first so the client is left disconnected even
        # when close() raises.
        app, self._app = self._app, None
        if app is not None:
            app.close()

    async def enqueue(
        self,
        task_name: str,
        payload: dict[str, Any] | None = None,
        *,
        task_id: str | None = None,
        queue: str | None = None,
        delay: int | timedelta | None = None,
    ) -> str:
        """Send a task to Celery by name.

        Args:
            task_name: Name of the task to send.
            payload: Keyword arguments passed to the task.
            task_id: Optional explicit task id; forwarded to Celery so the
                same id can later be passed to :meth:`get_status`. When
                omitted, Celery generates one.
            queue: Optional queue name overriding the default queue.
            delay: Optional countdown before execution, as a number of
                seconds or a ``timedelta`` (truncated to whole seconds).

        Returns:
            The id of the task as reported by Celery.

        Raises:
            QueueNotConnectedError: If the client has no Celery app.
            QueueProviderError: If sending the task fails.
        """
        if self._app is None:
            raise QueueNotConnectedError()

        options: dict[str, Any] = {}
        if task_id is not None:
            options["task_id"] = task_id
        if queue is not None:
            options["queue"] = queue
        if delay is not None:
            options["countdown"] = (
                int(delay.total_seconds()) if isinstance(delay, timedelta) else delay
            )

        try:
            # send_task does blocking broker I/O; keep it off the event loop.
            result = await asyncio.to_thread(
                self._app.send_task,
                task_name,
                kwargs=payload,
                **options,
            )
        except Exception as exc:
            raise QueueProviderError(str(exc) or "Failed to enqueue task") from exc
        return str(result.id)

    async def get_status(self, task_id: str) -> TaskStatus:
        """Look up the state of a task.

        Args:
            task_id: Id of the task to look up.

        Returns:
            A ``TaskStatus`` whose ``result`` is set only for successful tasks
            and whose ``error`` is set only for failed tasks. Celery states
            missing from ``_CELERY_STATE_MAP`` are reported as
            ``TaskState.UNKNOWN``.

        Raises:
            QueueNotConnectedError: If the client has no Celery app.
            QueueProviderError: If the status lookup fails.
        """
        if self._app is None:
            raise QueueNotConnectedError()

        def _fetch_status() -> tuple[str, Any]:
            r = celery_result.AsyncResult(task_id, app=self._app)
            return r.state, r.result

        try:
            # Reading state/result does blocking I/O; keep it off the loop.
            raw_state, raw_result = await asyncio.to_thread(_fetch_status)
            state = _CELERY_STATE_MAP.get(raw_state, TaskState.UNKNOWN)
            task_result = raw_result if state == TaskState.SUCCESS else None
            error = str(raw_result) if state == TaskState.FAILURE else None
        except Exception as exc:
            raise QueueProviderError(str(exc) or "Failed to get task status") from exc

        return TaskStatus(
            task_id=task_id,
            state=state,
            result=task_result,
            error=error,
        )

    def register_schedules(self, schedules: Sequence[Schedule]) -> None:
        """Register recurring schedules with Celery Beat.

        Replaces any previously registered schedules. If the app already
        exists the beat schedule is updated immediately; otherwise it is
        applied when the app is created.
        """
        self._schedules = list(schedules)
        if self._app is not None:
            self._apply_beat_schedule()

    def get_schedules(self) -> list[Schedule]:
        """Return the currently registered schedules.

        A copy is returned so callers cannot mutate the client's internal
        list behind its back.
        """
        return list(self._schedules)

    def _apply_beat_schedule(self) -> None:
        """Write schedules into Celery's beat_schedule config.

        Raises:
            QueueNotConnectedError: If no Celery app exists yet.
        """
        if self._app is None:
            raise QueueNotConnectedError()

        beat_schedule: dict[str, Any] = {}
        for s in self._schedules:
            entry: dict[str, Any] = {"task": s.task}
            if s.type == ScheduleType.CRON and s.cron:
                # Missing trailing cron fields default to "*"; extra fields
                # beyond the fifth are ignored.
                fields = (s.cron.split() + ["*"] * 5)[:5]
                minute, hour, day_of_month, month_of_year, day_of_week = fields
                entry["schedule"] = celery_schedules.crontab(
                    minute=minute,
                    hour=hour,
                    day_of_month=day_of_month,
                    month_of_year=month_of_year,
                    day_of_week=day_of_week,
                )
            elif s.type == ScheduleType.INTERVAL and s.interval:
                entry["schedule"] = s.interval.total_seconds()
            # NOTE(review): a schedule matching neither branch is registered
            # without a "schedule" key -- confirm whether that should be
            # rejected instead.
            if s.payload:
                entry["kwargs"] = s.payload
            if s.queue:
                entry["options"] = {"queue": s.queue}
            beat_schedule[s.name] = entry
        self._app.conf.beat_schedule = beat_schedule