"""AI client wrapping AsyncOpenAI."""
from __future__ import annotations
import asyncio
import json as _json
from collections.abc import AsyncIterator, Sequence
from typing import Any
import httpx
from etils import epy
from derp.ai.exceptions import (
FalJobAlreadyCompletedError,
FalJobFailedError,
FalJobNotFoundError,
FalMissingCredentialsError,
ModalNotConnectedError,
)
from derp.ai.models import (
CancelResult,
CancelState,
ChatChunk,
ChatResponse,
JobState,
JobStatus,
Tool,
ToolCall,
ToolEventType,
Usage,
_build_tool_map,
_parse_tool_call,
)
from derp.config import AIConfig
with epy.lazy_imports():
import fal_client
import openai
class AIClient:
    """Async AI client wrapping several providers.

    Chat completions go through :class:`openai.AsyncOpenAI`. Fal jobs go
    through :class:`fal_client.AsyncClient`, which is only created when
    ``config.fal_api_key`` is set. Modal endpoints go through an
    :class:`httpx.AsyncClient` that :meth:`connect` opens when
    ``config.modal`` is set.

    Example::

        config = AIConfig(api_key="...")
        ai = AIClient(config)
        response = await ai.chat(
            model="gpt-4o-mini",
            messages=[{"role": "user", "content": "Hello"}],
        )
    """

    def __init__(self, config: AIConfig):
        """Create the provider clients described by *config*.

        The Modal HTTP client is not created here; call :meth:`connect`.

        Args:
            config: AI configuration (API keys, base URL, optional Modal
                settings).
        """
        self._config = config
        self.client: openai.AsyncOpenAI = openai.AsyncOpenAI(
            api_key=config.api_key,
            base_url=config.base_url,
        )
        self._fal_client: fal_client.AsyncClient | None = None
        if config.fal_api_key is not None:
            self._fal_client = fal_client.AsyncClient(key=config.fal_api_key)
        self._modal_client: httpx.AsyncClient | None = None

    async def connect(self) -> None:
        """Open the Modal HTTP client if Modal is configured.

        Calling this more than once is safe. An already-open client is
        reused instead of being replaced, because replacing it would leak
        the old client's connections.
        """
        modal = self._config.modal
        if modal is None or self._modal_client is not None:
            return
        self._modal_client = httpx.AsyncClient(
            headers={
                "Modal-Key": modal.token_id,
                "Modal-Secret": modal.token_secret,
            },
            base_url=modal.endpoint_url or "",
        )

    async def disconnect(self) -> None:
        """Close the Modal HTTP client opened by :meth:`connect`, if any."""
        client, self._modal_client = self._modal_client, None
        if client is not None:
            await client.aclose()

    @staticmethod
    def _to_usage(raw: Any) -> Usage | None:
        """Convert a provider usage object into :class:`Usage`.

        Returns ``None`` when *raw* is falsy (e.g. the provider sent no usage).
        """
        if not raw:
            return None
        return Usage(
            prompt_tokens=raw.prompt_tokens,
            completion_tokens=raw.completion_tokens,
            total_tokens=raw.total_tokens,
        )

    @staticmethod
    def _merge_tool_call_delta(acc: dict[int, list[str]], tc_delta: Any) -> None:
        """Fold one streamed tool-call fragment into *acc*.

        *acc* maps the fragment's ``index`` to ``[id, name, arguments]``.
        A non-empty ``id`` or ``name`` overwrites the stored value.
        Argument text is appended, because arguments arrive in pieces.
        """
        entry = acc.setdefault(tc_delta.index, ["", "", ""])
        if tc_delta.id:
            entry[0] = tc_delta.id
        fn = tc_delta.function
        if fn:
            if fn.name:
                entry[1] = fn.name
            if fn.arguments:
                entry[2] += fn.arguments

    @staticmethod
    def _build_stream_tool_calls(
        acc: dict[int, list[str]],
        name_map: dict[str, type[Tool]] | None,
    ) -> list[ToolCall]:
        """Turn accumulated stream fragments into :class:`ToolCall` objects.

        Calls are ordered by stream index. Arguments are validated against
        the matching Tool class when the function name is known.
        """
        calls: list[ToolCall] = []
        for idx in sorted(acc):
            tc_id, fn_name, raw_args = acc[idx]
            parsed: Any = None
            if name_map and fn_name in name_map:
                # A tool without parameters may stream no argument text at
                # all, and "" is not valid JSON; treat it as an empty object.
                parsed = name_map[fn_name].model_validate_json(raw_args or "{}")
            calls.append(
                ToolCall(
                    id=tc_id,
                    function_name=fn_name,
                    arguments=raw_args,
                    args=parsed,
                )
            )
        return calls

    @staticmethod
    def _tool_result_content(result: Any) -> str:
        """Serialize a tool result for a ``role="tool"`` message.

        Strings pass through unchanged; anything else is JSON-encoded.
        """
        return result if isinstance(result, str) else _json.dumps(result)

    async def chat(
        self,
        model: str,
        *,
        messages: list[dict[str, Any]],
        tools: Sequence[type[Tool]] = (),
        **kwargs: Any,
    ) -> ChatResponse:
        """Create a chat completion.

        Args:
            model: Model ID to use.
            messages: List of message dicts.
            tools: Optional list of Tool subclasses.
            **kwargs: Additional arguments forwarded to the API.

        Returns:
            ChatResponse built from the first choice. It carries the text
            content, role, model, usage, finish reason and parsed tool calls.
        """
        name_map: dict[str, type[Tool]] | None = None
        if tools:
            schemas, name_map = _build_tool_map(tools)
            kwargs["tools"] = schemas
        completion = await self.client.chat.completions.create(
            model=model,
            messages=messages,
            **kwargs,
        )
        choice = completion.choices[0]
        message = choice.message
        tool_calls = [
            _parse_tool_call(tc, name_map) for tc in message.tool_calls or ()
        ]
        return ChatResponse(
            content=message.content or "",
            role=message.role,
            model=completion.model,
            usage=self._to_usage(completion.usage),
            finish_reason=choice.finish_reason or "stop",
            tool_calls=tool_calls,
        )

    async def stream_chat(
        self,
        model: str,
        *,
        messages: list[dict[str, Any]],
        tools: Sequence[type[Tool]] = (),
        **kwargs: Any,
    ) -> AsyncIterator[ChatChunk]:
        """Create a streaming chat completion.

        Args:
            model: Model ID to use.
            messages: List of message dicts.
            tools: Optional list of Tool subclasses.
            **kwargs: Additional arguments forwarded to the API.
                ``stream_options`` defaults to ``{"include_usage": True}``.

        Yields:
            A ChatChunk for each non-empty text delta.

            Once the stream reports a finish reason, a final chunk follows
            with ``is_last=True``. It carries the finish reason, usage (if
            sent) and any parsed tool calls. If no finish reason is ever
            seen, this final chunk is not emitted.
        """
        name_map: dict[str, type[Tool]] | None = None
        if tools:
            schemas, name_map = _build_tool_map(tools)
            kwargs["tools"] = schemas
        kwargs.setdefault("stream_options", {"include_usage": True})
        stream = await self.client.chat.completions.create(
            model=model,
            messages=messages,
            stream=True,
            **kwargs,
        )
        first = True
        finish_reason: str | None = None
        finish_model: str = model
        usage: Usage | None = None
        # Streamed tool-call fragments: index -> [id, name, arguments].
        tc_acc: dict[int, list[str]] = {}
        async for chunk in stream:
            # Chunks may have no choices, e.g. the trailing usage-only chunk
            # that OpenAI sends when include_usage is set.
            if chunk.choices:
                choice = chunk.choices[0]
                delta = choice.delta
                if delta.content:
                    yield ChatChunk(
                        delta=delta.content,
                        role=delta.role or "assistant",
                        model=getattr(chunk, "model", model),
                        is_first=first,
                    )
                    first = False
                for tc_delta in delta.tool_calls or ():
                    self._merge_tool_call_delta(tc_acc, tc_delta)
                if choice.finish_reason:
                    finish_reason = choice.finish_reason
                    finish_model = getattr(chunk, "model", model)
            if chunk.usage:
                usage = self._to_usage(chunk.usage)
        if finish_reason:
            yield ChatChunk(
                delta="",
                model=finish_model,
                finish_reason=finish_reason,
                usage=usage,
                tool_calls=self._build_stream_tool_calls(tc_acc, name_map),
                is_last=True,
            )

    async def stream_agent(
        self,
        model: str,
        *,
        messages: list[dict[str, Any]],
        tools: Sequence[type[Tool]] = (),
        tool_args: Sequence[Any] = (),
        max_turns: int = 10,
        **kwargs: Any,
    ) -> AsyncIterator[ChatChunk]:
        """Stream a chat completion loop, auto-executing tool calls.

        Each turn streams via :meth:`stream_chat`, and every chunk is
        re-yielded as it arrives. When a turn's final chunk carries tool
        calls, the loop:

        - appends an assistant message (any text from that turn plus the
          tool calls) to *messages*;
        - runs each call via ``tc.run(*tool_args)``;
        - appends each result as a ``role="tool"`` message;
        - starts the next turn.

        The loop ends when a turn produces no tool calls or after
        *max_turns* turns. In the second case the last tool results stay
        in *messages* without a follow-up completion.

        Args:
            model: Model ID to use.
            messages: List of message dicts (mutated in place).
            tools: Tool subclasses available to the model.
            tool_args: Extra positional args forwarded to each tool's
                ``run`` call (e.g. request-scoped state).
            max_turns: Maximum number of streamed turns.
            **kwargs: Additional arguments forwarded to the API.

        Yields:
            Every chunk from :meth:`stream_chat` across all turns.

            For each executed tool, three lifecycle chunks follow:
            ``INPUT_START``, ``INPUT_AVAILABLE`` (with the decoded arguments)
            and ``OUTPUT_AVAILABLE`` (with the tool's result).
        """
        for _ in range(max_turns):
            last_chunk: ChatChunk | None = None
            text_parts: list[str] = []
            async for chunk in self.stream_chat(
                model=model, messages=messages, tools=tools, **kwargs
            ):
                last_chunk = chunk
                text_parts.append(chunk.delta or "")
                yield chunk
            if last_chunk is None or not last_chunk.tool_calls:
                return
            assistant_message: dict[str, Any] = {
                "role": "assistant",
                "tool_calls": [
                    {
                        "id": tc.id,
                        "type": "function",
                        "function": {
                            "name": tc.function_name,
                            "arguments": tc.arguments,
                        },
                    }
                    for tc in last_chunk.tool_calls
                ],
            }
            text = "".join(text_parts)
            if text:
                # Keep text the model streamed alongside its tool calls so the
                # next turn sees the whole assistant reply, not just the calls.
                assistant_message["content"] = text
            messages.append(assistant_message)
            # Execute each tool, yield lifecycle events, append results.
            for tc in last_chunk.tool_calls:
                args = _json.loads(tc.arguments) if tc.arguments else {}
                yield ChatChunk(
                    delta="",
                    tool_event=ToolEventType.INPUT_START,
                    tool_call_id=tc.id,
                    tool_name=tc.function_name,
                )
                yield ChatChunk(
                    delta="",
                    tool_event=ToolEventType.INPUT_AVAILABLE,
                    tool_call_id=tc.id,
                    tool_name=tc.function_name,
                    tool_input=args,
                )
                result = await tc.run(*tool_args)
                yield ChatChunk(
                    delta="",
                    tool_event=ToolEventType.OUTPUT_AVAILABLE,
                    tool_call_id=tc.id,
                    tool_name=tc.function_name,
                    tool_output=result,
                )
                messages.append(
                    {
                        "role": "tool",
                        "tool_call_id": tc.id,
                        "content": self._tool_result_content(result),
                    }
                )

    def _require_fal_client(self) -> fal_client.AsyncClient:
        """Return the fal client.

        Raises:
            FalMissingCredentialsError: If no fal API key was configured.
        """
        if self._fal_client is None:
            raise FalMissingCredentialsError()
        return self._fal_client

    @staticmethod
    def _fal_job_state(status: Any) -> JobState:
        """Map a fal_client status object onto :class:`JobState`.

        A ``Completed`` status that carries an error maps to ``FAILED``.
        An unrecognised status maps to ``UNKNOWN``.
        """
        if isinstance(status, fal_client.Queued):
            return JobState.QUEUED
        if isinstance(status, fal_client.InProgress):
            return JobState.IN_PROGRESS
        if isinstance(status, fal_client.Completed):
            return JobState.FAILED if status.error else JobState.COMPLETED
        return JobState.UNKNOWN

    async def fal_submit(
        self, app: str, *, inputs: dict[str, Any], start_timeout: float = 10.0
    ) -> str:
        """Submit a job to a Fal application.

        Args:
            app: Fal application name.
            inputs: Inputs to the model.
            start_timeout: Start timeout in seconds. Default is 10 seconds.

        Returns:
            Request ID of the submitted task.

        Raises:
            FalMissingCredentialsError: If no fal API key was configured.
        """
        result = await self._require_fal_client().submit(
            app,
            arguments=inputs,
            start_timeout=start_timeout,
        )
        return result.request_id

    async def fal_call(
        self,
        app: str,
        *,
        inputs: dict[str, Any],
        poll_interval: float = 2.0,
        timeout: float = 60.0,
        start_timeout: float = 10.0,
    ) -> dict[str, Any]:
        """Submit a fal job and wait for the result.

        Convenience method combining :meth:`fal_submit`, :meth:`fal_poll`,
        and :meth:`fal_get` into a single call.

        Args:
            app: Fal application name.
            inputs: Inputs to the model.
            poll_interval: Seconds between status polls. Default is 2.
            timeout: Maximum seconds to wait after submission. Default is 60.
            start_timeout: Start timeout in seconds. Default is 10.

        Returns:
            Result dict from the completed job.

        Raises:
            FalMissingCredentialsError: If no fal API key was configured.
            FalJobFailedError: If the job fails.
            TimeoutError: If the job does not complete within *timeout*.
                This is raised by :func:`asyncio.wait_for` as
                ``asyncio.TimeoutError``, which is the builtin on Python
                3.11+. The remote job is not cancelled in this case.
        """
        request_id = await self.fal_submit(
            app, inputs=inputs, start_timeout=start_timeout
        )

        async def _poll() -> dict[str, Any]:
            while True:
                status = await self.fal_poll(app, request_id)
                if status.is_completed:
                    return await self.fal_get(app, request_id)
                if status.is_failed:
                    raise FalJobFailedError(status.error or "Fal job failed")
                await asyncio.sleep(poll_interval)

        return await asyncio.wait_for(_poll(), timeout=timeout)

    async def fal_poll(self, app: str, request_id: str) -> JobStatus:
        """Poll the status of a fal job.

        Args:
            app: Fal application name.
            request_id: Request ID returned by fal_submit.

        Returns:
            JobStatus with the current state of the job. It also carries
            the queue position, logs, metrics or error details, depending
            on the state.

        Raises:
            FalMissingCredentialsError: If no fal API key was configured.
        """
        handle = self._require_fal_client().get_handle(app, request_id)
        status = await handle.status()
        state = self._fal_job_state(status)
        if isinstance(status, fal_client.Queued):
            return JobStatus(state=state, position=status.position)
        if isinstance(status, fal_client.InProgress):
            return JobStatus(state=state, logs=status.logs)
        if isinstance(status, fal_client.Completed):
            return JobStatus(
                state=state,
                logs=status.logs,
                metrics=status.metrics,
                error=status.error,
                error_type=status.error_type,
            )
        return JobStatus(state=state)

    async def fal_get(self, app: str, request_id: str) -> dict[str, Any]:
        """Get the result of a fal job.

        Args:
            app: Fal application name.
            request_id: Request ID returned by fal_submit.

        Returns:
            Result of the job as a dict.

        Raises:
            FalMissingCredentialsError: If no fal API key was configured.
        """
        handle = self._require_fal_client().get_handle(app, request_id)
        return await handle.get()

    async def fal_cancel(self, app: str, request_id: str) -> CancelResult:
        """Cancel a fal job.

        Args:
            app: Fal application name.
            request_id: Request ID returned by fal_submit.

        Returns:
            CancelResult with the cancellation state. Its ``job_state`` is
            the state observed just before the cancel request was sent.

        Raises:
            FalMissingCredentialsError: If no fal API key was configured.
            FalJobAlreadyCompletedError: If the cancel request returns HTTP 400.
            FalJobNotFoundError: If the cancel request returns HTTP 404.
        """
        handle = self._require_fal_client().get_handle(app, request_id)
        job_state = self._fal_job_state(await handle.status())
        try:
            await handle.cancel()
        except fal_client.FalClientHTTPError as exc:
            if exc.status_code == 400:
                raise FalJobAlreadyCompletedError() from exc
            if exc.status_code == 404:
                raise FalJobNotFoundError() from exc
            raise
        return CancelResult(
            state=CancelState.CANCELLATION_REQUESTED,
            job_state=job_state,
        )

    async def modal_call(
        self, endpoint: str, *, inputs: dict[str, Any], timeout: float = 30.0
    ) -> dict[str, Any]:
        """Call a Modal endpoint.

        Args:
            endpoint: Modal endpoint name (relative to the configured
                ``endpoint_url``, if any).
            inputs: Inputs to the endpoint, sent as a JSON body.
            timeout: Timeout in seconds. Default is 30 seconds.

        Returns:
            Decoded JSON body of the endpoint's response.

        Raises:
            ModalNotConnectedError: If :meth:`connect` has not opened a
                Modal client.
            httpx.HTTPStatusError: If the endpoint answers with a non-2xx
                status. Without this check, error bodies would be returned
                as if they were results.
        """
        if self._modal_client is None:
            raise ModalNotConnectedError()
        response = await self._modal_client.post(
            endpoint,
            json=inputs,
            timeout=timeout,
        )
        response.raise_for_status()
        return response.json()