Source code for ml_toolkit.functions.llm.udtf.interface

from typing import Literal, Optional, Union

from pyspark.sql import DataFrame
from yipit_databricks_client.helpers.telemetry import track_usage

from ml_toolkit.functions.llm.udtf.base import run_udtf_batch
from ml_toolkit.functions.llm.udtf.config import (
    AgentUDTFConfig,
    InferenceConfig,
    MultiModelUDTFConfig,
    UDTFConfig,
)
from ml_toolkit.functions.llm.udtf.templates.agentic import (
    setup_agent_udtf as _setup_agent_udtf,
)
from ml_toolkit.functions.llm.udtf.templates.inference import (
    setup_inference_udtf as _setup_inference_udtf,
)
from ml_toolkit.functions.llm.udtf.templates.multi_model import (
    setup_multi_model_udtf as _setup_multi_model_udtf,
)
from ml_toolkit.ops.helpers.logger import get_logger
from ml_toolkit.ops.helpers.validation import for_loop_guardrail


@track_usage
def setup_udtf(
    config: UDTFConfig,
    name: Optional[str] = None,
    prompt_column: str = "prompt",
) -> type:
    """Register a PySpark UDTF based on the config type.

    Dispatches to the appropriate UDTF implementation based on the config
    class. After registration, the UDTF can be called from Spark SQL.

    Supported config types:

    - :class:`AgentUDTFConfig` — Runs an OpenAI Agents SDK agent per row.
    - :class:`MultiModelUDTFConfig` — Sends each prompt to multiple models
      in parallel, yielding one row per (input, model).
    - :class:`InferenceConfig` — Runs single-model LLM inference per row.

    :param config: A :class:`UDTFConfig` subclass instance.
    :param name: UDTF name for SQL registration. Defaults vary by type:
        ``run_{config.name}`` for agents, ``"multi_model_inference"`` for
        multi-model, ``"_inference"`` for plain inference.
    :param prompt_column: Column name containing the user prompt.
        Defaults to ``"prompt"``.
    :returns: The UDTF class (also registered in the SparkSession).
    :raises TypeError: If *config* is not a supported subclass.

    Examples
    ^^^^^^^^^^

    .. code-block:: python

        from ml_toolkit.functions.llm.udtf import AgentUDTFConfig, setup_udtf

        config = AgentUDTFConfig(
            name="my_agent",
            instructions="You are a helpful assistant.",
            model="openai/databricks-claude-3-7-sonnet",
        )
        setup_udtf(config)
        spark.sql("SELECT * FROM run_my_agent(TABLE(SELECT * FROM prompts))")

    .. code-block:: python

        from ml_toolkit.functions.llm.udtf import MultiModelUDTFConfig, setup_udtf

        config = MultiModelUDTFConfig(max_workers=20)
        setup_udtf(config)
        spark.sql(\"\"\"
            SELECT * FROM multi_model_inference(
                TABLE(SELECT prompt, 'gpt-4o,claude-3' AS models FROM prompts)
            )
        \"\"\")
    """
    logger = get_logger()

    if isinstance(config, AgentUDTFConfig):
        udtf_name = name or f"run_{config.name}"
        logger.info(f"Setting up agent UDTF: {udtf_name}")
        return _setup_agent_udtf(
            config=config, name=udtf_name, prompt_column=prompt_column
        )

    if isinstance(config, MultiModelUDTFConfig):
        udtf_name = name or "multi_model_inference"
        logger.info(f"Setting up multi-model UDTF: {udtf_name}")
        return _setup_multi_model_udtf(config=config, name=udtf_name)

    if isinstance(config, InferenceConfig):
        udtf_name = name or "_inference"
        logger.info(f"Setting up inference UDTF: {udtf_name}")
        return _setup_inference_udtf(
            config=config, name=udtf_name, prompt_column=prompt_column
        )

    raise TypeError(
        f"Unsupported config type: {type(config).__name__}. "
        "Use AgentUDTFConfig, MultiModelUDTFConfig, or InferenceConfig."
    )

@track_usage
@for_loop_guardrail(min_interval_seconds=10)
def run_agent_batch(
    data_source: Union[DataFrame, str],
    agent_config: AgentUDTFConfig,
    prompt_column: str = "prompt",
    output_table_name: Optional[str] = None,
    dry_run: Optional[bool] = None,
    table_operation: Literal["overwrite", "append"] = "overwrite",
    cost_component_name: Optional[str] = None,
) -> dict:
    """Run an agent against every row in a DataFrame or table.

    This is the high-level batch API (analogous to :func:`run_llm_batch`).
    It registers a temporary UDTF, runs it against all input rows, and
    writes results to a Delta table or returns them as a DataFrame.

    There are two operation modes:

    1. ``dry_run=True``: Returns results as a DataFrame without writing.
       Only works with fewer than 1,000 rows.
    2. ``dry_run=False``: Writes results to ``output_table_name``. Can
       ``overwrite`` or ``append`` (set via ``table_operation``).

    .. attention::

        You **must** pass a ``cost_component_name``, otherwise this
        function will raise an exception.

    :param data_source: Input DataFrame or fully-qualified Delta table name.
    :param agent_config: :class:`AgentUDTFConfig` defining the agent.
    :param prompt_column: Column name containing the user prompt.
        Defaults to ``"prompt"``.
    :param output_table_name: Fully-qualified output Delta table name
        (required for batch mode).
    :param dry_run: If True, returns results as a DataFrame. If None,
        auto-detected based on row count.
    :param table_operation: ``"overwrite"`` or ``"append"`` for the output
        table.
    :param cost_component_name: Required cost component for telemetry.
    :returns: Dict with ``source_uuid``, ``model``, ``df``,
        ``output_table_name``, ``cost_component_name``, ``error_message``,
        and ``agent_name``.
    :raises ValueError: If ``cost_component_name`` is missing or
        ``output_table_name`` is missing in batch mode.
    :raises MLOpsToolkitTooManyRowsForInteractiveUsage: If ``dry_run=True``
        with too many rows.

    Examples
    ^^^^^^^^^^

    .. code-block:: python

        from ml_toolkit.functions.llm import AgentConfig, run_agent_batch

        config = AgentConfig(
            name="corp_insights",
            instructions="You analyze corporate data for the {{database}} team.",
            model="openai/databricks-claude-3-7-sonnet",
            include_sql_tool=True,
            include_web_search_tool=True,
        )

        result = run_agent_batch(
            data_source="catalog.schema.input_prompts",
            agent_config=config,
            prompt_column="prompt",
            output_table_name="catalog.schema.agent_output",
            cost_component_name="my_team",
        )
        display(result["df"])
    """
    result = run_udtf_batch(
        data_source=data_source,
        udtf_setup_fn=_setup_agent_udtf,
        config=agent_config,
        prompt_column=prompt_column,
        output_table_name=output_table_name,
        dry_run=dry_run,
        table_operation=table_operation,
        cost_component_name=cost_component_name,
        model_name=agent_config.model,
    )

    # Add agent-specific metadata from the config to the result dict
    result["agent_name"] = agent_config.name
    return result
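

# A minimal interactive-usage sketch, assuming a Spark session and a small
# prompt table are available; the table name and cost component below are
# illustrative placeholders, not values defined by this module. It exercises
# the ``dry_run=True`` path described in ``run_agent_batch``.
if __name__ == "__main__":  # pragma: no cover
    example_config = AgentUDTFConfig(
        name="example_agent",
        instructions="You are a helpful assistant.",
        model="openai/databricks-claude-3-7-sonnet",
    )
    result = run_agent_batch(
        data_source="catalog.schema.input_prompts",  # illustrative table name
        agent_config=example_config,
        dry_run=True,  # return a DataFrame without writing an output table
        cost_component_name="my_team",  # required for telemetry; placeholder value
    )
    result["df"].show()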