from typing import Literal, Optional, Union
from pyspark.sql import DataFrame
from yipit_databricks_client.helpers.telemetry import track_usage
from ml_toolkit.functions.llm.udtf.base import run_udtf_batch
from ml_toolkit.functions.llm.udtf.config import (
AgentUDTFConfig,
InferenceConfig,
MultiModelUDTFConfig,
UDTFConfig,
)
from ml_toolkit.functions.llm.udtf.templates.agentic import (
setup_agent_udtf as _setup_agent_udtf,
)
from ml_toolkit.functions.llm.udtf.templates.inference import (
setup_inference_udtf as _setup_inference_udtf,
)
from ml_toolkit.functions.llm.udtf.templates.multi_model import (
setup_multi_model_udtf as _setup_multi_model_udtf,
)
from ml_toolkit.ops.helpers.logger import get_logger
from ml_toolkit.ops.helpers.validation import for_loop_guardrail
@track_usage
def setup_udtf(
    config: UDTFConfig,
    name: Optional[str] = None,
    prompt_column: str = "prompt",
) -> type:
    """Register a PySpark UDTF based on the config type.

    Dispatches to the appropriate UDTF implementation based on the config
    class. After registration, the UDTF can be called from Spark SQL.

    Supported config types (checked in this order; the first match wins):

    - :class:`AgentUDTFConfig` — Runs an OpenAI Agents SDK agent per row.
    - :class:`MultiModelUDTFConfig` — Sends each prompt to multiple models
      in parallel, yielding one row per (input, model).
    - :class:`InferenceConfig` — Sets up the inference UDTF
      (``setup_inference_udtf``).

    :param config: A :class:`UDTFConfig` subclass instance.
    :param name: UDTF name for SQL registration. When not given (or empty),
        defaults by type: ``run_{config.name}`` for agents,
        ``"multi_model_inference"`` for multi-model, and ``"_inference"``
        for inference.
    :param prompt_column: Column name containing the user prompt. Defaults to
        ``"prompt"``. Forwarded to the agent and inference UDTFs; it is not
        passed to the multi-model UDTF.
    :returns: The UDTF class (also registered in SparkSession).
    :raises TypeError: If *config* is not a supported subclass.

    Examples
    ^^^^^^^^^^

    .. code-block:: python

        from ml_toolkit.functions.llm.udtf import AgentUDTFConfig, setup_udtf

        config = AgentUDTFConfig(
            name="my_agent",
            instructions="You are a helpful assistant.",
            model="openai/databricks-claude-3-7-sonnet",
        )
        setup_udtf(config)
        spark.sql("SELECT * FROM run_my_agent(TABLE(SELECT * FROM prompts))")

    .. code-block:: python

        from ml_toolkit.functions.llm.udtf import MultiModelUDTFConfig, setup_udtf

        config = MultiModelUDTFConfig(max_workers=20)
        setup_udtf(config)
        spark.sql(\"\"\"
            SELECT * FROM multi_model_inference(
                TABLE(SELECT prompt, 'gpt-4o,claude-3' AS models FROM prompts)
            )
        \"\"\")
    """
    logger = get_logger()

    # isinstance checks run in order, so a config that subclasses more than
    # one of these types is handled by the earliest matching branch.
    if isinstance(config, AgentUDTFConfig):
        udtf_name = name or f"run_{config.name}"
        logger.info(f"Setting up agent UDTF: {udtf_name}")
        return _setup_agent_udtf(
            config=config, name=udtf_name, prompt_column=prompt_column
        )

    if isinstance(config, MultiModelUDTFConfig):
        udtf_name = name or "multi_model_inference"
        logger.info(f"Setting up multi-model UDTF: {udtf_name}")
        # prompt_column is intentionally not forwarded here; the multi-model
        # setup helper is called without it.
        return _setup_multi_model_udtf(config=config, name=udtf_name)

    if isinstance(config, InferenceConfig):
        # NOTE(review): the leading underscore in this default name looks
        # unusual — confirm it is intentional before relying on it in SQL.
        udtf_name = name or "_inference"
        logger.info(f"Setting up inference UDTF: {udtf_name}")
        return _setup_inference_udtf(
            config=config, name=udtf_name, prompt_column=prompt_column
        )

    raise TypeError(
        f"Unsupported config type: {type(config).__name__}. "
        "Use AgentUDTFConfig, MultiModelUDTFConfig, or InferenceConfig."
    )
@track_usage
@for_loop_guardrail(min_interval_seconds=10)
def run_agent_batch(
    data_source: Union[DataFrame, str],
    agent_config: AgentUDTFConfig,
    prompt_column: str = "prompt",
    output_table_name: Optional[str] = None,
    dry_run: Optional[bool] = None,
    table_operation: Literal["overwrite", "append"] = "overwrite",
    cost_component_name: Optional[str] = None,
) -> dict:
    """Run an agent against every row in a DataFrame or table.

    This is the high-level batch API (analogous to :func:`run_llm_batch`). It registers
    a temporary UDTF, runs it against all input rows, and writes results to a Delta table
    or returns them as a DataFrame.

    There are two operation modes:

    1. ``dry_run=True``: Returns results as a DataFrame without writing. Only works with
       fewer than 1,000 rows.
    2. ``dry_run=False``: Writes results to ``output_table_name``. Can ``overwrite`` or
       ``append`` (set via ``table_operation``).

    .. attention:: You **must** pass a ``cost_component_name``, otherwise this function
       will raise an exception.

    :param data_source: Input DataFrame or fully-qualified Delta table name.
    :param agent_config: :class:`AgentUDTFConfig` defining the agent.
    :param prompt_column: Column name containing the user prompt. Defaults to ``"prompt"``.
    :param output_table_name: Fully-qualified output Delta table name (required for batch mode).
    :param dry_run: If True, returns results as DataFrame. If None, auto-detected based on row count.
    :param table_operation: ``"overwrite"`` or ``"append"`` for the output table.
    :param cost_component_name: Required cost component for telemetry.
    :returns: Dict with ``source_uuid``, ``model``, ``df``, ``output_table_name``,
        ``cost_component_name``, and ``error_message``, plus ``agent_name``
        (set to ``agent_config.name``).
    :raises ValueError: If ``cost_component_name`` is missing or ``output_table_name``
        missing in batch mode.
    :raises MLOpsToolkitTooManyRowsForInteractiveUsage: If ``dry_run=True`` with too many rows.

    Examples
    ^^^^^^^^^^

    .. code-block:: python

        from ml_toolkit.functions.llm.udtf import AgentUDTFConfig, run_agent_batch

        config = AgentUDTFConfig(
            name="corp_insights",
            instructions="You analyze corporate data for the {{database}} team.",
            model="openai/databricks-claude-3-7-sonnet",
            include_sql_tool=True,
            include_web_search_tool=True,
        )
        result = run_agent_batch(
            data_source="catalog.schema.input_prompts",
            agent_config=config,
            prompt_column="prompt",
            output_table_name="catalog.schema.agent_output",
            cost_component_name="my_team",
        )
        display(result["df"])
    """
    # All validation, UDTF registration, execution and writing is delegated to
    # the shared batch runner; the agent-specific setup function is injected.
    result = run_udtf_batch(
        data_source=data_source,
        udtf_setup_fn=_setup_agent_udtf,
        config=agent_config,
        prompt_column=prompt_column,
        output_table_name=output_table_name,
        dry_run=dry_run,
        table_operation=table_operation,
        cost_component_name=cost_component_name,
        model_name=agent_config.model,
    )
    # Add agent-specific metadata on top of the generic batch result.
    result["agent_name"] = agent_config.name
    return result