from __future__ import annotations
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Callable, Optional
from ml_toolkit.functions.llm.udtf.constants import (
DEFAULT_MAX_RETRIES,
DEFAULT_MAX_WORKERS,
DEFAULT_TIMEOUT_SEC,
)
# Defaults specific to agent-based UDTFs.
DEFAULT_AGENT_MODEL = "databricks/claude-3-7-sonnet"
DEFAULT_MAX_TURNS = 25

# Output columns of the multi-model UDTF: one row per (input, model) pair.
_MULTI_MODEL_COLUMNS = ("model", "output", "parameters", "error", "raw_output")

# Spark DDL return schema derived from the column list above.
MULTI_MODEL_RETURN_TYPE = ", ".join(
    f"{column}: string" for column in _MULTI_MODEL_COLUMNS
)
@dataclass
class UDTFConfig:
    """Base configuration for any PySpark UDTF in the toolkit.

    Defines shared execution parameters for the ThreadPoolExecutor-based
    UDTF pattern. Extend this class for specific UDTF types.

    :param max_workers: Number of concurrent workers in the thread pool per partition.
    :param timeout_sec: Per-row timeout in seconds.
    :param max_retries: Maximum retries on transient errors (429, 5xx).
    :param cost_component_name: Cost component for usage tracking, passed as
        ``metadata={"cost_component_name": ...}`` in OpenAI API calls.
        Subclasses may set a default.
    """

    # Thread-pool concurrency within a single Spark partition.
    max_workers: int = DEFAULT_MAX_WORKERS
    # Per-row timeout in seconds.
    timeout_sec: int = DEFAULT_TIMEOUT_SEC
    # Retry budget for transient errors (429, 5xx).
    max_retries: int = DEFAULT_MAX_RETRIES
    # Optional cost-attribution tag; None unless a subclass sets a default.
    cost_component_name: Optional[str] = None
@dataclass
class AgentUDTFConfig(UDTFConfig):
    """Configuration for an OpenAI Agents SDK agent UDTF.

    Extends :class:`UDTFConfig` with agent-specific parameters: model,
    instructions, tools, and LiteLLM routing.

    :param name: Unique name for this agent (used in UDTF registration and telemetry).
    :param instructions: System instructions for the agent. Supports Jinja2
        templates with variables from the input row (e.g., ``{{database}}``).
    :param model: LiteLLM model identifier (e.g., ``"openai/gpt-5"``).
    :param tools: List of Python callables to register as agent tools. Each function
        should have type annotations and a docstring (used to auto-generate the tool schema).
    :param include_sql_tool: Whether to include the built-in read-only SQL execution tool.
    :param include_web_search_tool: Whether to include the built-in web search tool.
    :param max_turns: Maximum number of agent turns (tool call rounds).
    :param warehouse_id: Databricks SQL warehouse ID for the SQL tool.
        Resolved from ``DATABRICKS_WAREHOUSE_ID`` env var or Databricks secrets if None.
    :param litellm_base_url: LiteLLM proxy base URL.
        Resolved from Databricks secrets if None.
    :param litellm_api_key: LiteLLM proxy API key.
        Resolved from Databricks secrets if None.
    :param extra_agent_kwargs: Additional kwargs passed to the ``agents.Agent`` constructor.

    Examples
    ^^^^^^^^^^

    .. code-block:: python
        :caption: Basic configuration with custom tools.

        from ml_toolkit.functions.llm.udtf import AgentUDTFConfig

        def my_lookup(query: str) -> str:
            \"\"\"Look up data in a custom API.\"\"\"
            return "result"

        config = AgentUDTFConfig(
            name="my_agent",
            instructions="You are a research assistant.",
            model="openai/gpt-5",
            tools=[my_lookup],  # auto-wrapped with agents.function_tool()
        )

    .. code-block:: python
        :caption: Configuration with built-in tools and Jinja2 instructions.

        config = AgentUDTFConfig(
            name="corp_insights",
            instructions="You analyze data for the {{database}} team.",
            model="openai/databricks-claude-3-7-sonnet",
            include_sql_tool=True,
            include_web_search_tool=True,
        )

    .. note::
        Tools are plain Python callables with type annotations and docstrings.
        They are automatically wrapped with ``agents.function_tool()`` at
        agent build time. The OpenAI Agents SDK uses the function signature
        and docstring to generate the tool schema that the LLM sees.
    """

    # Agent identity and behavior
    name: str = "agent"
    instructions: str = ""
    model: str = DEFAULT_AGENT_MODEL
    tools: list[Callable] = field(default_factory=list)
    include_sql_tool: bool = False
    include_web_search_tool: bool = False
    # Agent execution
    max_turns: int = DEFAULT_MAX_TURNS
    # Infrastructure
    warehouse_id: Optional[str] = None
    litellm_base_url: Optional[str] = None
    litellm_api_key: Optional[str] = None
    # Extensibility
    extra_agent_kwargs: dict[str, Any] = field(default_factory=dict)

    def __post_init__(self) -> None:
        # Fail fast on driver-side misconfiguration instead of erroring
        # later on every executor row.
        if not self.name or not self.name.strip():
            raise ValueError("AgentUDTFConfig 'name' is required.")
        if not self.instructions or not self.instructions.strip():
            raise ValueError("AgentUDTFConfig 'instructions' is required.")

    def to_dict(self) -> dict[str, Any]:
        """Serialize config to a dict for telemetry/logging (excludes callables)."""
        return {
            "name": self.name,
            # Truncate instructions to keep telemetry payloads small.
            "instructions": self.instructions[:200],
            "model": self.model,
            # Callables are not serializable; record tool names only.
            "tools": [t.__name__ for t in self.tools],
            "include_sql_tool": self.include_sql_tool,
            "include_web_search_tool": self.include_web_search_tool,
            "max_turns": self.max_turns,
            "max_workers": self.max_workers,
            "timeout_sec": self.timeout_sec,
            "max_retries": self.max_retries,
            "warehouse_id": self.warehouse_id,
        }

    @classmethod
    def from_yaml(cls, path: str | Path, **overrides: Any) -> "AgentUDTFConfig":
        """Load config from a YAML file.

        The YAML should contain keys matching the dataclass fields
        (excluding ``tools``, which must be passed via overrides).

        :param path: Path to the YAML config file.
        :param overrides: Additional kwargs to override YAML values.
        """
        # Local import: PyYAML is only required by this constructor.
        import yaml

        path = Path(path)
        with path.open("r", encoding="utf-8") as f:
            data = yaml.safe_load(f) or {}
        # 'tools' can't be loaded from YAML; must be passed as override
        data.pop("tools", None)
        data.update(overrides)
        return cls(**data)
@dataclass
class MultiModelUDTFConfig(UDTFConfig):
    """Configuration for a multi-model inference UDTF.

    Extends :class:`UDTFConfig` with LiteLLM credentials and optional
    structured-output schema. The UDTF sends each input row to multiple
    models in parallel and yields one output row per (input, model).

    :param litellm_url: LiteLLM proxy base URL.
        Resolved from ``LITELLM_PROXY_BASE_URL`` env var or Databricks
        secrets if None.
    :param litellm_api_key: LiteLLM proxy API key.
        Resolved from ``LITELLM_API_KEY`` env var or Databricks secrets
        if None.
    :param response_schema: Optional JSON-schema dict for structured output.
        When set, passed as ``response_format`` to the chat completions API.

    Examples
    ^^^^^^^^^^

    .. code-block:: python

        from ml_toolkit.functions.llm.udtf import (
            MultiModelUDTFConfig,
            setup_multi_model_udtf,
        )

        config = MultiModelUDTFConfig(max_workers=20, max_retries=3)
        setup_multi_model_udtf(config)

        # Input must have: prompt, models (comma-separated), optional system_prompt
        spark.sql(\"\"\"
            SELECT * FROM multi_model_inference(
                TABLE(SELECT prompt, 'gpt-4o,claude-3' AS models FROM my_table)
            )
        \"\"\")
    """

    litellm_url: Optional[str] = None
    litellm_api_key: Optional[str] = None
    # Plain default is equivalent to (and simpler than) field(default=None).
    response_schema: Optional[dict] = None

    def to_dict(self) -> dict[str, Any]:
        """Serialize config to a dict for telemetry/logging (excludes credentials)."""
        return {
            "max_workers": self.max_workers,
            "timeout_sec": self.timeout_sec,
            "max_retries": self.max_retries,
            # Schemas can be large; record presence only.
            "has_response_schema": self.response_schema is not None,
        }
@dataclass
class InferenceConfig(UDTFConfig):
    """Configuration for inference UDTF.

    Extends :class:`UDTFConfig` with model, prompt, and LiteLLM credentials
    for running chat completion inference.

    :param model: LiteLLM model identifier (e.g., ``"gpt-5"``).
    :param prompt_template: Jinja2 template using ``<<variable>>`` syntax.
        If None, the raw input column value is used as the prompt.
    :param system_prompt: Optional system prompt for the chat completion.
    :param max_output_tokens: Maximum output tokens for model response.
    :param input_column: Column name containing the input text.
    :param additional_context_columns: Additional columns available in template context.
    :param litellm_base_url: LiteLLM proxy base URL.
        Resolved from Databricks secrets if None.
    :param litellm_api_key: LiteLLM proxy API key.
        Resolved from Databricks secrets if None.
    :param litellm_model_kwargs: Additional kwargs for the chat completions API.
    """

    model: str = ""
    prompt_template: Optional[str] = None
    system_prompt: Optional[str] = None
    max_output_tokens: int = 8192
    input_column: str = "input"
    # Defaults to an empty list (not None) via default_factory so the
    # template context can be built without a None check.
    additional_context_columns: Optional[list[str]] = field(default_factory=list)
    litellm_base_url: Optional[str] = None
    litellm_api_key: Optional[str] = None
    litellm_model_kwargs: Optional[dict[str, Any]] = None
    # Overrides the base-class default for cost attribution.
    cost_component_name: Optional[str] = "ai_tagging"
@dataclass
class ServingEndpointConfig(UDTFConfig):
    """Configuration for serving endpoint UDTF.

    Extends :class:`UDTFConfig` for querying specific served models in
    multi-version serving endpoints using the Databricks REST API.

    :param endpoint_name: Name of the serving endpoint (optional, can be provided per-row).
    :param served_model_name: Served model name to route to (optional, can be provided per-row).
        Format: ``"{name}-{endpoint_name}"`` (e.g., ``"v1-my-endpoint"``).
    :param timeout_sec: Timeout in seconds for serving endpoint requests.
        Default is 900 seconds (15 minutes) to accommodate cold starts from scale-to-zero.

    Examples
    ^^^^^^^^^^

    .. code-block:: python
        :caption: Configuration with default endpoint routing.

        from ml_toolkit.functions.llm.udtf import ServingEndpointConfig

        config = ServingEndpointConfig(
            endpoint_name="my-endpoint",
            served_model_name="v1-my-endpoint",
        )

    .. note::
        Both ``endpoint_name`` and ``served_model_name`` can be provided
        per-row in the DataFrame, allowing dynamic routing across multiple
        endpoints and model versions.
    """

    endpoint_name: str = ""
    served_model_name: str = ""
    timeout_sec: int = 900  # 15 minutes to handle cold starts from scale-to-zero
    deployment_name: Optional[str] = None  # Auto-captured for serverless
    api_token: Optional[str] = None  # Auto-captured for serverless
@dataclass
class AiQueryWebSearchConfig(UDTFConfig):
    """Configuration for the AI Query with Web Search UDTF.

    Builds on :class:`UDTFConfig` to run inference with OpenAI's web
    search capability through a LiteLLM proxy.

    :param litellm_base_url: LiteLLM proxy base URL. Resolved from Databricks secrets if None.
    :param litellm_api_key: LiteLLM proxy API key. Resolved from Databricks secrets if None.
    :param model: Model name to use for inference.
    :param response_schema: Optional JSON schema dict for structured output.
    :param prompt_delimiter: Delimiter used to separate system and user prompts.

    Examples
    ^^^^^^^^^^

    .. code-block:: python
        :caption: Basic configuration.

        from ml_toolkit.functions.llm.udtf import AiQueryWebSearchConfig

        config = AiQueryWebSearchConfig(model="gpt-4o")
        setup_ai_query_web_search_udtf(config)

        spark.sql(\"\"\"
            SELECT * FROM ai_query_web_search(
                TABLE(SELECT prompt FROM my_table)
            )
        \"\"\")
    """

    litellm_base_url: Optional[str] = None
    litellm_api_key: Optional[str] = None
    model: str = "gpt-4o"
    response_schema: Optional[dict] = None
    prompt_delimiter: str = "---SYSTEM_USER_DELIMITER---"

    def to_dict(self) -> dict[str, Any]:
        """Serialize config to a dict for telemetry/logging."""
        summary: dict[str, Any] = {
            "max_workers": self.max_workers,
            "timeout_sec": self.timeout_sec,
            "max_retries": self.max_retries,
            "model": self.model,
        }
        # Schemas may be large; record only whether one is configured.
        summary["has_response_schema"] = self.response_schema is not None
        return summary