# Source: ml_toolkit.functions.llm.udtf.config

from __future__ import annotations

from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Callable, Optional

from ml_toolkit.functions.llm.udtf.constants import (
    DEFAULT_MAX_RETRIES,
    DEFAULT_MAX_WORKERS,
    DEFAULT_TIMEOUT_SEC,
)

# Agent-specific defaults
# LiteLLM-style model identifier used when AgentUDTFConfig.model is not set;
# max turns bounds the number of agent turns (tool-call rounds).
DEFAULT_AGENT_MODEL = "databricks/claude-3-7-sonnet"
DEFAULT_MAX_TURNS = 25

# Multi-model UDTF return schema
# Spark DDL-style column list for the multi-model inference UDTF output:
# one row per (input, model) combination.
MULTI_MODEL_RETURN_TYPE = (
    "model: string, output: string, parameters: string, "
    "error: string, raw_output: string"
)


@dataclass
class UDTFConfig:
    """Shared execution settings for every PySpark UDTF in the toolkit.

    Captures the parameters common to the ThreadPoolExecutor-based UDTF
    pattern; concrete UDTF types subclass this and add their own fields.

    :param max_workers: Number of concurrent workers in the thread pool per partition.
    :param timeout_sec: Per-row timeout in seconds.
    :param max_retries: Maximum retries on transient errors (429, 5xx).
    :param cost_component_name: Cost component for usage tracking, passed as
        ``metadata={"cost_component_name": ...}`` in OpenAI API calls.
        Subclasses may set a default.
    """

    max_workers: int = DEFAULT_MAX_WORKERS
    timeout_sec: int = DEFAULT_TIMEOUT_SEC
    max_retries: int = DEFAULT_MAX_RETRIES
    cost_component_name: Optional[str] = None


@dataclass
class AgentUDTFConfig(UDTFConfig):
    """Configuration for an OpenAI Agents SDK agent UDTF.

    Extends :class:`UDTFConfig` with agent-specific parameters: model,
    instructions, tools, and LiteLLM routing.

    :param name: Unique name for this agent (used in UDTF registration and
        telemetry).
    :param instructions: System instructions for the agent. Supports Jinja2
        templates with variables from the input row (e.g., ``{{database}}``).
    :param model: LiteLLM model identifier (e.g., ``"openai/gpt-5"``).
    :param tools: List of Python callables to register as agent tools. Each
        function should have type annotations and a docstring (used to
        auto-generate the tool schema).
    :param include_sql_tool: Whether to include the built-in read-only SQL
        execution tool.
    :param include_web_search_tool: Whether to include the built-in web
        search tool.
    :param max_turns: Maximum number of agent turns (tool call rounds).
    :param warehouse_id: Databricks SQL warehouse ID for the SQL tool.
        Resolved from ``DATABRICKS_WAREHOUSE_ID`` env var or Databricks
        secrets if None.
    :param litellm_base_url: LiteLLM proxy base URL. Resolved from
        Databricks secrets if None.
    :param litellm_api_key: LiteLLM proxy API key. Resolved from Databricks
        secrets if None.
    :param extra_agent_kwargs: Additional kwargs passed to the
        ``agents.Agent`` constructor.

    Examples
    ^^^^^^^^^^

    .. code-block:: python
        :caption: Basic configuration with custom tools.

        from ml_toolkit.functions.llm.udtf import AgentUDTFConfig

        def my_lookup(query: str) -> str:
            \"\"\"Look up data in a custom API.\"\"\"
            return "result"

        config = AgentUDTFConfig(
            name="my_agent",
            instructions="You are a research assistant.",
            model="openai/gpt-5",
            tools=[my_lookup],  # auto-wrapped with agents.function_tool()
        )

    .. code-block:: python
        :caption: Configuration with built-in tools and Jinja2 instructions.

        config = AgentUDTFConfig(
            name="corp_insights",
            instructions="You analyze data for the {{database}} team.",
            model="openai/databricks-claude-3-7-sonnet",
            include_sql_tool=True,
            include_web_search_tool=True,
        )

    .. note::
        Tools are plain Python callables with type annotations and
        docstrings. They are automatically wrapped with
        ``agents.function_tool()`` at agent build time. The OpenAI Agents
        SDK uses the function signature and docstring to generate the tool
        schema that the LLM sees.
    """

    name: str = "agent"
    instructions: str = ""
    model: str = DEFAULT_AGENT_MODEL
    tools: list[Callable] = field(default_factory=list)
    include_sql_tool: bool = False
    include_web_search_tool: bool = False

    # Agent execution
    max_turns: int = DEFAULT_MAX_TURNS

    # Infrastructure
    warehouse_id: Optional[str] = None
    litellm_base_url: Optional[str] = None
    litellm_api_key: Optional[str] = None

    # Extensibility
    extra_agent_kwargs: dict[str, Any] = field(default_factory=dict)

    def __post_init__(self) -> None:
        # Both name and instructions must be non-blank; the dataclass
        # defaults alone are not sufficient for a usable agent.
        if not self.name or not self.name.strip():
            raise ValueError("AgentUDTFConfig 'name' is required.")
        if not self.instructions or not self.instructions.strip():
            raise ValueError("AgentUDTFConfig 'instructions' is required.")

    def to_dict(self) -> dict[str, Any]:
        """Serialize config to a dict for telemetry/logging (excludes callables)."""
        return {
            "name": self.name,
            "instructions": self.instructions[:200],
            "model": self.model,
            "tools": [t.__name__ for t in self.tools],
            "include_sql_tool": self.include_sql_tool,
            "include_web_search_tool": self.include_web_search_tool,
            "max_turns": self.max_turns,
            "max_workers": self.max_workers,
            "timeout_sec": self.timeout_sec,
            "max_retries": self.max_retries,
            "warehouse_id": self.warehouse_id,
        }

    @classmethod
    def from_yaml(cls, path: str | Path, **overrides) -> "AgentUDTFConfig":
        """Load config from a YAML file.

        The YAML should contain keys matching the dataclass fields
        (excluding ``tools``, which must be passed via overrides).

        :param path: Path to the YAML config file.
        :param overrides: Additional kwargs to override YAML values.
        """
        import yaml  # local import: PyYAML only needed for this constructor

        yaml_path = Path(path)
        with yaml_path.open("r", encoding="utf-8") as fh:
            loaded = yaml.safe_load(fh) or {}
        # 'tools' can't be loaded from YAML; must be passed as override
        loaded.pop("tools", None)
        loaded.update(overrides)
        return cls(**loaded)
@dataclass
class MultiModelUDTFConfig(UDTFConfig):
    """Configuration for a multi-model inference UDTF.

    Extends :class:`UDTFConfig` with LiteLLM credentials and optional
    structured-output schema. The UDTF sends each input row to multiple
    models in parallel and yields one output row per (input, model).

    :param litellm_url: LiteLLM proxy base URL. Resolved from
        ``LITELLM_PROXY_BASE_URL`` env var or Databricks secrets if None.
    :param litellm_api_key: LiteLLM proxy API key. Resolved from
        ``LITELLM_API_KEY`` env var or Databricks secrets if None.
    :param response_schema: Optional JSON-schema dict for structured output.
        When set, passed as ``response_format`` to the chat completions API.

    Examples
    ^^^^^^^^^^

    .. code-block:: python

        from ml_toolkit.functions.llm.udtf import (
            MultiModelUDTFConfig,
            setup_multi_model_udtf,
        )

        config = MultiModelUDTFConfig(max_workers=20, max_retries=3)
        setup_multi_model_udtf(config)

        # Input must have: prompt, models (comma-separated), optional system_prompt
        spark.sql(\"\"\"
            SELECT * FROM multi_model_inference(
                TABLE(SELECT prompt, 'gpt-4o,claude-3' AS models FROM my_table)
            )
        \"\"\")
    """

    # NOTE(review): sibling configs call this field ``litellm_base_url``;
    # renaming it here would break existing callers, so the inconsistent
    # name is kept (the docstring above now matches the real field name).
    litellm_url: Optional[str] = None
    litellm_api_key: Optional[str] = None
    response_schema: Optional[dict] = None

    def to_dict(self) -> dict[str, Any]:
        """Serialize config to a dict for telemetry/logging."""
        return {
            "max_workers": self.max_workers,
            "timeout_sec": self.timeout_sec,
            "max_retries": self.max_retries,
            "has_response_schema": self.response_schema is not None,
        }
@dataclass
class InferenceConfig(UDTFConfig):
    """Configuration for inference UDTF.

    Extends :class:`UDTFConfig` with model, prompt, and LiteLLM credentials
    for running chat completion inference.

    :param model: LiteLLM model identifier (e.g., ``"gpt-5"``).
    :param prompt_template: Jinja2 template using ``<<variable>>`` syntax.
        If None, the raw input column value is used as the prompt.
    :param system_prompt: Optional system prompt for the chat completion.
    :param max_output_tokens: Maximum output tokens for model response.
    :param input_column: Column name containing the input text.
    :param additional_context_columns: Additional columns available in
        template context.
    :param litellm_base_url: LiteLLM proxy base URL. Resolved from
        Databricks secrets if None.
    :param litellm_api_key: LiteLLM proxy API key. Resolved from Databricks
        secrets if None.
    :param litellm_model_kwargs: Additional kwargs for the chat completions
        API.
    """

    model: str = ""
    prompt_template: Optional[str] = None
    system_prompt: Optional[str] = None
    max_output_tokens: int = 8192
    input_column: str = "input"
    additional_context_columns: Optional[list[str]] = field(default_factory=list)
    litellm_base_url: Optional[str] = None
    litellm_api_key: Optional[str] = None
    litellm_model_kwargs: Optional[dict[str, Any]] = None
    # Overrides the base-class default for usage tracking.
    cost_component_name: Optional[str] = "ai_tagging"
@dataclass
class ServingEndpointConfig(UDTFConfig):
    """Configuration for serving endpoint UDTF.

    Extends :class:`UDTFConfig` for querying specific served models in
    multi-version serving endpoints using the Databricks REST API.

    :param endpoint_name: Name of the serving endpoint (optional, can be
        provided per-row).
    :param served_model_name: Served model name to route to (optional, can
        be provided per-row). Format: ``"{name}-{endpoint_name}"``
        (e.g., ``"v1-my-endpoint"``).
    :param timeout_sec: Timeout in seconds for serving endpoint requests.
        Default is 900 seconds (15 minutes) to accommodate cold starts from
        scale-to-zero.

    Examples
    ^^^^^^^^^^

    .. code-block:: python
        :caption: Configuration with default endpoint routing.

        from ml_toolkit.functions.llm.udtf import ServingEndpointConfig

        config = ServingEndpointConfig(
            endpoint_name="my-endpoint",
            served_model_name="v1-my-endpoint",
        )

    .. note::
        Both ``endpoint_name`` and ``served_model_name`` can be provided
        per-row in the DataFrame, allowing dynamic routing across multiple
        endpoints and model versions.
    """

    endpoint_name: str = ""
    served_model_name: str = ""
    timeout_sec: int = 900  # 15 minutes to handle cold starts from scale-to-zero
    deployment_name: Optional[str] = None  # Auto-captured for serverless
    api_token: Optional[str] = None  # Auto-captured for serverless
@dataclass
class AiQueryWebSearchConfig(UDTFConfig):
    """Configuration for AI Query with Web Search UDTF.

    Extends :class:`UDTFConfig` for running inference with OpenAI's web
    search capability via LiteLLM proxy.

    :param litellm_base_url: LiteLLM proxy base URL. Resolved from
        Databricks secrets if None.
    :param litellm_api_key: LiteLLM proxy API key. Resolved from Databricks
        secrets if None.
    :param model: Model name to use for inference.
    :param response_schema: Optional JSON schema dict for structured output.
    :param prompt_delimiter: Delimiter used to separate system and user
        prompts.

    Examples
    ^^^^^^^^^^

    .. code-block:: python
        :caption: Basic configuration.

        from ml_toolkit.functions.llm.udtf import AiQueryWebSearchConfig

        config = AiQueryWebSearchConfig(model="gpt-4o")
        setup_ai_query_web_search_udtf(config)

        spark.sql(\"\"\"
            SELECT * FROM ai_query_web_search(
                TABLE(SELECT prompt FROM my_table)
            )
        \"\"\")
    """

    litellm_base_url: Optional[str] = None
    litellm_api_key: Optional[str] = None
    model: str = "gpt-4o"
    response_schema: Optional[dict] = None
    prompt_delimiter: str = "---SYSTEM_USER_DELIMITER---"

    def to_dict(self) -> dict[str, Any]:
        """Serialize config to a dict for telemetry/logging."""
        return {
            "max_workers": self.max_workers,
            "timeout_sec": self.timeout_sec,
            "max_retries": self.max_retries,
            "model": self.model,
            "has_response_schema": self.response_schema is not None,
        }