llm module#
The llm module of the ml_toolkit contains all the functions that enable LLM usage within Databricks notebooks.
We expose the following functions:
- run_llm_batch: performs row-level LLM querying over your data, writing each response to a new column.
- estimate_token_usage: estimates the token usage of a run_llm_batch call.
- query: directly query a model, optionally providing context and tools to perform an action.
- run_vector_search: performs vector search (nearest or hybrid) over a Databricks vector search index.
PySpark UDTFs for at-scale inference:
- setup_udtf: register any UDTF by passing the appropriate config (agent or multi-model).
- run_agent_batch: one-call batch execution of an agent UDTF with table writes.
Prompt management functions for the MLflow Prompt Registry:
- create_or_fetch_prompt: create or fetch prompts with automatic deduplication.
- load_prompt: load a prompt by name, version, or URI/alias.
- search_prompts: search for prompts in the registry.
- set_prompt_alias: set aliases (e.g., "production") for prompt versions.
- delete_prompt_version: delete a specific version of a prompt.
- delete_prompt: delete a prompt from the registry.
Serving endpoint management:
- deploy_model_serving_endpoint: deploy Unity Catalog model versions to a new serving endpoint.
- add_model_version_to_endpoint: add a model version to an existing endpoint with automatic traffic redistribution.
- update_endpoint_traffic: update traffic distribution for A/B testing and canary deployments.
- get_serving_endpoint: get endpoint details.
- list_serving_endpoints: list all endpoints with optional tag filtering.
- query_serving_endpoint: query a specific version of a serving endpoint.
Attention
These functions are in an experimental phase and are subject to change. If you have any feedback, please submit it via our Jira Form.
Attention
You must pass a cost_component_name to the functions that call the LLMs; otherwise they will raise an exception.
Quota Controls#
We have set strict usage controls, or quotas, to limit usage. Once you exceed your token limit, an exception is raised. LLM usage can become very expensive very quickly, so these limits are in place to prevent runaway costs.
If you have a valid use case that solves a business problem and needs a higher quota to run against your data, please submit a ticket. The approval process involves getting your manager to sign off on your use case and requested budget.
When submitting this, always include the output from estimate_token_usage.
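For a rough sanity check before filing a request, the common ~4 characters per token heuristic can be sketched as below; treat it as an illustration only, since estimate_token_usage remains the authoritative number (the function and overhead value here are assumptions, not part of the toolkit):

```python
def rough_token_estimate(rows: list[str], prompt_overhead: int = 50) -> int:
    """Very rough token estimate: ~4 characters per token of input text,
    plus a fixed per-row overhead for the prompt template (heuristic only)."""
    return sum(len(text) // 4 + prompt_overhead for text in rows)

total = rough_token_estimate(["Summarize this vendor record ...", "Another record"])
```

Comparing this rough figure against the estimate_token_usage output helps you catch order-of-magnitude mistakes before requesting a budget.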
Prompts#
The prompt functions provide a complete interface for managing prompts in the MLflow Prompt Registry. Prompts are versioned, support aliases (e.g., “production”, “staging”), and use content-based deduplication.
from ml_toolkit.functions.llm import create_or_fetch_prompt, load_prompt, set_prompt_alias
# Create or fetch a prompt (deduplicates automatically)
result = create_or_fetch_prompt(
name="mycatalog.myschema.qa_prompt",
template="Answer this question: {{question}}",
commit_message="Initial QA prompt"
)
# Set an alias for easy reference
set_prompt_alias("mycatalog.myschema.qa_prompt", alias="production", version=result["version"])
# Load by alias
prompt = load_prompt("prompts:/mycatalog.myschema.qa_prompt@production")
create_or_fetch_prompt#
- ml_toolkit.functions.llm.create_or_fetch_prompt(name: str, template: str, commit_message: str | None = None, tags: Dict[str, str] | None = None, create_if_missing: bool = True, **kwargs) dict | None[source]#
Create or fetch a prompt from MLflow Prompt Registry with automatic deduplication across all versions.
This function intelligently handles prompt creation by checking if an identical template already exists in ANY version (not just the latest). It uses a SHA256 hash of the template stored as a tag to avoid creating duplicate versions with the same content.
Key Benefits:

- True Deduplication: Avoids creating duplicate versions by checking ALL existing versions
- Handles Rollbacks: Returns the existing version if you revert to an older template
- Content-based lookup: Finds prompts by their content, not just name
- Efficient versioning: Only creates new versions when the template actually changes
- Idempotent: Safe to call multiple times with the same template
Documentation: https://docs.databricks.com/aws/en/mlflow3/genai/prompt-version-mgmt/prompt-registry/examples#register_prompt
How It Works#
Calculates the SHA256 hash of the template (the full 256-bit digest, making collisions practically impossible)
Searches ALL existing versions of the prompt
Checks each version’s template_hash tag for a match
If any version matches, returns that version (no new version created)
If no match found, creates new version with hash tag
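The same decision logic can be sketched in plain Python, with a dict standing in for the registry (this is an illustration of the steps above, not the toolkit's internal code):

```python
import hashlib

def template_hash(template: str) -> str:
    """SHA256 digest of the template content, stored as a version tag."""
    return hashlib.sha256(template.encode("utf-8")).hexdigest()

def create_or_reuse(versions: dict[int, dict], template: str) -> int:
    """versions maps version number -> {'template': ..., 'tags': {...}}.
    Returns the matching version if the hash exists in ANY version,
    otherwise creates a new version carrying the hash tag."""
    h = template_hash(template)
    for v, info in versions.items():
        if info["tags"].get("template_hash") == h:
            return v  # identical content already registered: reuse it
    new_version = max(versions, default=0) + 1
    versions[new_version] = {"template": template, "tags": {"template_hash": h}}
    return new_version

store: dict[int, dict] = {}
v1 = create_or_reuse(store, "Answer: {{question}}")    # creates version 1
v2 = create_or_reuse(store, "Reply to: {{question}}")  # creates version 2
v3 = create_or_reuse(store, "Answer: {{question}}")    # returns 1, not 3
```

Because the lookup scans every version rather than only the latest, reverting to an old template resolves back to that old version instead of minting a duplicate.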
Template Syntax#
Use double curly braces {{variable_name}} to define placeholders in your prompt template:

template = "You are a {{role}}. Answer this question: {{question}}"
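Placeholder substitution itself happens when the prompt is used; a minimal regex-based stand-in for illustration (the registry's own formatting API may differ):

```python
import re

def render(template: str, **values: str) -> str:
    """Replace each {{name}} placeholder with the supplied value (illustrative only)."""
    return re.sub(r"\{\{(\w+)\}\}", lambda m: values[m.group(1)], template)

text = render(
    "You are a {{role}}. Answer this question: {{question}}",
    role="geographer",
    question="What is the capital of France?",
)
```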
- Parameters:
  - name – Fully qualified name of the prompt in format "catalog.schema.prompt_name". Must be a valid Unity Catalog three-level namespace.
  - template – The prompt template string. Use {{variable_name}} syntax for placeholders.
  - commit_message – Optional message describing this version of the prompt.
  - tags – Optional dict of key-value tags for storing metadata. All values must be strings. Note: 'template_hash' will be automatically added to tags.
  - create_if_missing – If True (default), creates a new version if needed. If False, returns None when the prompt doesn't exist.
- Returns: Dictionary containing prompt details including name, version, template, and tags. Returns None if create_if_missing=False and the prompt is not found.
- Raises: ValueError – If name or template is empty/None, or if the name format is invalid.

Examples
Basic usage - first call creates, second call reuses#

from ml_toolkit.functions.llm import create_or_fetch_prompt

# First call - creates version 1
result = create_or_fetch_prompt(
    name="mycatalog.myschema.qa_prompt",
    template="Answer this question: {{question}}",
    commit_message="Initial QA prompt"
)
print(f"Version: {result['version']}")  # 1

# Second call with same template - returns existing version 1
result = create_or_fetch_prompt(
    name="mycatalog.myschema.qa_prompt",
    template="Answer this question: {{question}}",  # Same template
)
print(f"Version: {result['version']}")  # Still 1 (no new version)

# Third call with different template - creates version 2
result = create_or_fetch_prompt(
    name="mycatalog.myschema.qa_prompt",
    template="Please provide an answer to: {{question}}",  # Different
    commit_message="Updated wording"
)
print(f"Version: {result['version']}")  # 2

# Fourth call reverting to v1 template - returns v1 (not v3!)
result = create_or_fetch_prompt(
    name="mycatalog.myschema.qa_prompt",
    template="Answer this question: {{question}}",  # Same as v1
)
print(f"Version: {result['version']}")  # Returns 1 (found existing!)

With tags for metadata tracking#

from ml_toolkit.functions.llm import create_or_fetch_prompt

result = create_or_fetch_prompt(
    name="mycatalog.myschema.summarizer",
    template="""Summarize the following text in {{num_sentences}} sentences:

Text: {{content}}

Focus on: {{focus_areas}}""",
    commit_message="Summarization prompt with focus areas",
    tags={
        "tested_with": "gpt-4",
        "avg_latency_ms": "1200",
        "team": "content",
        "project": "summarization-v2"
    }
)
print(f"Created/fetched version {result['version']}")
print(f"Template hash: {result['tags']['template_hash']}")

Check if prompt exists without creating#

from ml_toolkit.functions.llm import create_or_fetch_prompt

# Try to fetch, but don't create if missing
result = create_or_fetch_prompt(
    name="mycatalog.myschema.new_prompt",
    template="Some template: {{input}}",
    create_if_missing=False
)

if result is None:
    print("Prompt doesn't exist or template changed")
else:
    print(f"Found existing prompt version {result['version']}")

Idempotent prompt deployment#

from ml_toolkit.functions.llm import create_or_fetch_prompt

# Safe to run multiple times - only creates a new version if the template changes
def deploy_prompt():
    prompt = create_or_fetch_prompt(
        name="mycatalog.myschema.customer_support",
        template="""You are a helpful customer support assistant.

Customer question: {{question}}
Customer context: {{context}}

Provide a helpful, professional response.""",
        commit_message="Customer support prompt deployment",
        tags={"environment": "production", "team": "support"}
    )
    return prompt

# First deployment - creates version 1
result1 = deploy_prompt()

# Subsequent deployments - reuses version 1 (template unchanged)
result2 = deploy_prompt()
assert result1['version'] == result2['version']
load_prompt#
- ml_toolkit.functions.llm.load_prompt(name_or_uri: str, version: str | int | None = None, allow_missing: bool = False, link_to_model: bool = True, model_id: str | None = None) dict[source]#
Load a prompt from MLflow Prompt Registry.
This function retrieves a registered prompt from the MLflow prompt registry. You can load prompts by name with a specific version, or by using a prompt URI that includes the version or alias.
Documentation: https://docs.databricks.com/aws/en/mlflow3/genai/prompt-version-mgmt/prompt-registry/examples#load_prompt

Loading Methods#
1. By name and version: specify name and version separately
2. By URI with version: use prompts:/catalog.schema.prompt_name/version
3. By URI with alias: use prompts:/catalog.schema.prompt_name@alias

- Parameters:
  - name_or_uri – Either a prompt URI (e.g., "prompts:/catalog.schema.prompt@production") or a prompt name (e.g., "catalog.schema.prompt_name").
  - version – Specific version number to load. Can be an integer or string. Ignored if name_or_uri includes a version/alias.
  - allow_missing – If True, returns None when the prompt is not found instead of raising an error. Default is False.
  - link_to_model – Whether to link this prompt to a model. Default is True.
  - model_id – Optional model ID to link the prompt to.
- Returns: Dictionary containing prompt details including name, version, template, and tags. Returns None if allow_missing=True and the prompt is not found.
- Raises: ValueError – If name_or_uri is empty/None.
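The three methods boil down to extracting a (name, version, alias) triple from the reference; an illustrative parser, not the toolkit's internal code:

```python
def parse_prompt_uri(name_or_uri: str):
    """Split a prompt reference into (name, version, alias); unused parts are None."""
    if not name_or_uri.startswith("prompts:/"):
        return name_or_uri, None, None          # plain name; version passed separately
    ref = name_or_uri[len("prompts:/"):]
    if "@" in ref:
        name, alias = ref.rsplit("@", 1)
        return name, None, alias                # prompts:/catalog.schema.p@production
    if "/" in ref:
        name, version = ref.rsplit("/", 1)
        return name, int(version), None         # prompts:/catalog.schema.p/2
    return ref, None, None
```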
Examples
Load prompt by name and version#

from ml_toolkit.functions.llm import load_prompt

result = load_prompt(
    name_or_uri="mycatalog.myschema.qa_prompt",
    version=2
)
print(f"Template: {result['template']}")
print(f"Version: {result['version']}")

Load prompt using URI with version#

from ml_toolkit.functions.llm import load_prompt

result = load_prompt(
    name_or_uri="prompts:/mycatalog.myschema.qa_prompt/2"
)
print(f"Loaded version {result['version']}")

Load prompt using alias (e.g., production)#

from ml_toolkit.functions.llm import load_prompt

result = load_prompt(
    name_or_uri="prompts:/mycatalog.myschema.qa_prompt@production"
)
print(f"Production version: {result['version']}")

Handle missing prompts gracefully#

from ml_toolkit.functions.llm import load_prompt

result = load_prompt(
    name_or_uri="mycatalog.myschema.nonexistent_prompt",
    version=1,
    allow_missing=True
)

if result is None:
    print("Prompt not found")
else:
    print(f"Found: {result['name']}")
search_prompts#
- ml_toolkit.functions.llm.search_prompts(filter_string: str | None = None, max_results: int | None = None) List[dict][source]#
Search for prompts in MLflow Prompt Registry.
This function searches for prompts in the MLflow prompt registry using SQL-like filter strings. For Unity Catalog, you must specify catalog and schema. The function returns a list of prompts matching the search criteria.
Documentation: https://docs.databricks.com/aws/en/mlflow3/genai/prompt-version-mgmt/prompt-registry/examples#search_prompts
Filter Limitations (Unity Catalog)#
Required: Must specify catalog and schema
Cannot filter by: Name patterns, tags, or exact prompt names
Workaround: Retrieve all prompts in a catalog/schema, then filter programmatically
- Parameters:
  - filter_string – Optional SQL-like filter string. For Unity Catalog: "catalog = 'mycatalog' AND schema = 'myschema'".
  - max_results – Maximum number of prompts to return. If None, returns all matches.
- Returns: List of dictionaries, each containing prompt details (name, latest_version, tags).
Examples
Search all prompts in a catalog and schema#

from ml_toolkit.functions.llm import search_prompts

prompts = search_prompts(
    filter_string="catalog = 'mycatalog' AND schema = 'myschema'"
)
for prompt in prompts:
    print(f"{prompt['name']}: v{prompt['latest_version']}")

Limit search results#

from ml_toolkit.functions.llm import search_prompts

prompts = search_prompts(
    filter_string="catalog = 'mycatalog' AND schema = 'myschema'",
    max_results=10
)
print(f"Found {len(prompts)} prompts (max 10)")

Filter by name after retrieval (workaround for Unity Catalog)#

from ml_toolkit.functions.llm import search_prompts

# Get all prompts first
all_prompts = search_prompts(
    filter_string="catalog = 'mycatalog' AND schema = 'myschema'"
)

# Filter by name pattern programmatically
qa_prompts = [p for p in all_prompts if 'qa' in p['name']]
print(f"Found {len(qa_prompts)} QA prompts")

Search and display prompt details#

from ml_toolkit.functions.llm import search_prompts

prompts = search_prompts(
    filter_string="catalog = 'yd_production' AND schema = 'ml_prompts'"
)
for prompt in prompts:
    print(f"Name: {prompt['name']}")
    print(f"Latest Version: {prompt['latest_version']}")
    if prompt['tags']:
        print(f"Tags: {prompt['tags']}")
    print("---")
set_prompt_alias#
- ml_toolkit.functions.llm.set_prompt_alias(name: str, alias: str, version: int) dict[source]#
Set an alias for a specific prompt version in MLflow Prompt Registry.
This function assigns an alias (like “production”, “staging”, or “latest”) to a specific version of a prompt. Aliases make it easier to reference stable versions of prompts without hardcoding version numbers.
Documentation: https://docs.databricks.com/aws/en/mlflow3/genai/prompt-version-mgmt/prompt-registry/examples#set_prompt_alias
Common Aliases#
- production: The current production version
- staging: Version being tested before production
- latest: Most recent stable version
- champion: Best performing version
- Parameters:
  - name – Fully qualified name of the prompt in format "catalog.schema.prompt_name".
  - alias – Alias name to assign (e.g., "production", "staging").
  - version – Version number to assign the alias to. Must be an integer.
- Returns: Dictionary containing the name, alias, and version.
- Raises: ValueError – If name or alias is empty/None, or if version is not positive.
Examples
Set production alias#

from ml_toolkit.functions.llm import set_prompt_alias

result = set_prompt_alias(
    name="mycatalog.myschema.qa_prompt",
    alias="production",
    version=3
)
print(f"Set {result['alias']} alias to version {result['version']}")

Promote a version to production#

from ml_toolkit.functions.llm import set_prompt_alias, load_prompt

# Test version 5 in staging
set_prompt_alias(
    name="mycatalog.myschema.customer_support",
    alias="staging",
    version=5
)

# After testing, promote to production
set_prompt_alias(
    name="mycatalog.myschema.customer_support",
    alias="production",
    version=5
)

# Load using alias
prompt = load_prompt("prompts:/mycatalog.myschema.customer_support@production")

Update alias to new version#

from ml_toolkit.functions.llm import set_prompt_alias

# Update production alias from v3 to v4
set_prompt_alias(
    name="mycatalog.myschema.summarizer",
    alias="production",
    version=4  # This replaces the previous production alias
)

Set multiple aliases for different environments#

from ml_toolkit.functions.llm import set_prompt_alias

prompt_name = "mycatalog.myschema.translator"

# Development version
set_prompt_alias(prompt_name, alias="dev", version=7)

# Staging version
set_prompt_alias(prompt_name, alias="staging", version=6)

# Production version
set_prompt_alias(prompt_name, alias="production", version=5)
delete_prompt_version#
- ml_toolkit.functions.llm.delete_prompt_version(name: str, version: str | int) dict[source]#
Delete a specific version of a prompt from MLflow Prompt Registry.
This function permanently deletes a specific version of a prompt from the MLflow prompt registry. In Unity Catalog, all versions must be deleted before you can delete the prompt itself using delete_prompt().

Documentation: https://docs.databricks.com/aws/en/mlflow3/genai/prompt-version-mgmt/prompt-registry/examples#delete_prompt_version
Caution
This operation is permanent and cannot be undone. Make sure you have backups or exports of any prompt versions you want to preserve.
- Parameters:
  - name – Fully qualified name of the prompt in format "catalog.schema.prompt_name".
  - version – Version to delete. Can be a string or integer (e.g., "2" or 2).
- Returns: Dictionary containing the name and deleted version.
- Raises: ValueError – If name or version is empty/None, or if the name format is invalid.
Examples
Delete a specific prompt version#

from ml_toolkit.functions.llm import delete_prompt_version

result = delete_prompt_version(
    name="mycatalog.myschema.qa_prompt",
    version="2"
)
print(f"Deleted version {result['version']} of {result['name']}")

Delete old versions to clean up#

from ml_toolkit.functions.llm import delete_prompt_version

prompt_name = "mycatalog.myschema.customer_support"

# Delete old versions that are no longer needed
old_versions = ["1", "2", "3"]
for ver in old_versions:
    try:
        delete_prompt_version(prompt_name, version=ver)
        print(f"Deleted version {ver}")
    except Exception as e:
        print(f"Failed to delete version {ver}: {e}")

Delete all versions before deleting prompt#

from ml_toolkit.functions.llm import delete_prompt_version, delete_prompt

prompt_name = "mycatalog.myschema.deprecated_prompt"
versions_to_delete = ["1", "2", "3", "4", "5"]

# Delete all versions
for version in versions_to_delete:
    delete_prompt_version(prompt_name, version=version)
    print(f"Deleted version {version}")

# Now safe to delete the prompt itself
delete_prompt(prompt_name)
print(f"Deleted prompt {prompt_name}")

Delete with error handling#

from ml_toolkit.functions.llm import delete_prompt_version

try:
    result = delete_prompt_version(
        name="mycatalog.myschema.test_prompt",
        version=3
    )
    print(f"Successfully deleted version {result['version']}")
except Exception as e:
    print(f"Failed to delete version: {e}")
delete_prompt#
- ml_toolkit.functions.llm.delete_prompt(name: str) dict[source]#
Delete a prompt from MLflow Prompt Registry.
This function permanently deletes a prompt from the MLflow prompt registry. For Unity Catalog registries, you MUST delete all versions first before deleting the prompt itself.
Documentation: https://docs.databricks.com/aws/en/mlflow3/genai/prompt-version-mgmt/prompt-registry/examples#delete_prompt
Warning
Unity Catalog Requirement: All prompt versions must be deleted before deleting the prompt. Use delete_prompt_version() to delete each version first; otherwise, this operation will fail.

Caution
This operation is permanent and cannot be undone. Make sure you have backups or exports of any prompts you want to preserve.
- Parameters:
  - name – Fully qualified name of the prompt in format "catalog.schema.prompt_name".
- Returns: Dictionary containing the name of the deleted prompt.
- Raises:
  - ValueError – If name is empty/None or the format is invalid.
  - Exception – If the prompt has undeleted versions (Unity Catalog).
Examples
Delete a prompt (after deleting all versions)#

from ml_toolkit.functions.llm import delete_prompt_version, delete_prompt

prompt_name = "mycatalog.myschema.old_prompt"

# First, delete all versions
delete_prompt_version(prompt_name, version="1")
delete_prompt_version(prompt_name, version="2")
delete_prompt_version(prompt_name, version="3")

# Then delete the prompt
result = delete_prompt(prompt_name)
print(f"Deleted prompt: {result['name']}")

Delete all versions programmatically before deleting prompt#

from ml_toolkit.functions.llm import (
    search_prompts,
    load_prompt,
    delete_prompt_version,
    delete_prompt
)

prompt_name = "mycatalog.myschema.deprecated_prompt"

# Load the prompt to get version info
prompt = load_prompt(prompt_name, version=1)

# Delete all versions (you need to know how many exist)
# In practice, you might need to query this information
for version in range(1, 5):  # versions 1-4
    try:
        delete_prompt_version(prompt_name, version=str(version))
        print(f"Deleted version {version}")
    except Exception as e:
        print(f"Version {version} not found or already deleted")

# Delete the prompt
delete_prompt(prompt_name)

Safe deletion with error handling#

from ml_toolkit.functions.llm import delete_prompt

try:
    result = delete_prompt("mycatalog.myschema.test_prompt")
    print(f"Successfully deleted: {result['name']}")
except Exception as e:
    print(f"Failed to delete prompt: {e}")
    print("Make sure all versions are deleted first!")
UDTFs (User-Defined Table Functions)#
The udtf submodule provides PySpark UDTFs for running LLM inference at scale
inside Spark SQL queries. All UDTFs are registered through a single function,
setup_udtf, which dispatches based on the config type:
- AgentUDTFConfig — Runs an OpenAI Agents SDK agent (with tools, Jinja2 instructions) per row.
- MultiModelUDTFConfig — Sends each prompt to multiple models in parallel and yields one row per model.
- InferenceConfig — Runs a single-model chat completion per row with prompt templating.
- ServingEndpointConfig — Queries a specific served model in a multi-version serving endpoint via REST API.
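Under stand-in config classes (not the real UDTFConfig hierarchy), the type-based dispatch that setup_udtf performs can be sketched as:

```python
from dataclasses import dataclass

# Hypothetical stand-ins for the real config classes
@dataclass
class AgentCfg:
    name: str

@dataclass
class MultiModelCfg:
    max_workers: int = 10

def dispatch(config):
    """Route a config object to the matching UDTF builder based on its type."""
    handlers = {AgentCfg: "agent_udtf", MultiModelCfg: "multi_model_udtf"}
    try:
        return handlers[type(config)]
    except KeyError:
        raise TypeError(f"Unsupported config type: {type(config).__name__}") from None
```

This mirrors the documented behavior: a supported config type selects an implementation, and anything else raises TypeError.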
setup_udtf#
- ml_toolkit.functions.llm.udtf.interface.setup_udtf(config: ml_toolkit.functions.llm.udtf.config.UDTFConfig, name: str = None, prompt_column: str = 'prompt') type[source]#
Register a PySpark UDTF based on the config type.
Dispatches to the appropriate UDTF implementation based on the config class. After registration, the UDTF can be called from Spark SQL.
Supported config types:
- AgentUDTFConfig — Runs an OpenAI Agents SDK agent per row.
- MultiModelUDTFConfig — Sends each prompt to multiple models in parallel, yielding one row per (input, model).
- Parameters:
  - config – A UDTFConfig subclass instance.
  - name – UDTF name for SQL registration. Defaults vary by type: run_{config.name} for agents, "multi_model_inference" for multi-model.
  - prompt_column – Column name containing the user prompt. Defaults to "prompt".
- Returns: The UDTF class (also registered in the SparkSession).
- Raises: TypeError – If config is not a supported subclass.
Examples
from ml_toolkit.functions.llm.udtf import AgentUDTFConfig, setup_udtf

config = AgentUDTFConfig(
    name="my_agent",
    instructions="You are a helpful assistant.",
    model="openai/databricks-claude-3-7-sonnet",
)
setup_udtf(config)

spark.sql("SELECT * FROM run_my_agent(TABLE(SELECT * FROM prompts))")

from ml_toolkit.functions.llm.udtf import MultiModelUDTFConfig, setup_udtf

config = MultiModelUDTFConfig(max_workers=20)
setup_udtf(config)

spark.sql("""
    SELECT * FROM multi_model_inference(
        TABLE(SELECT prompt, 'gpt-4o,claude-3' AS models FROM prompts)
    )
""")
Agent UDTF#
Register an agent as a PySpark UDTF, then call it from SQL or use the high-level batch API.
from ml_toolkit.functions.llm.udtf import AgentUDTFConfig, setup_udtf, run_agent_batch
config = AgentUDTFConfig(
name="my_agent",
instructions="You are a helpful assistant.",
model="openai/databricks-claude-3-7-sonnet",
tools=[my_tool_function],
)
setup_udtf(config)
results = spark.sql("SELECT * FROM run_my_agent(TABLE(SELECT * FROM prompts))")
# Or use the high-level batch API with table writes:
result = run_agent_batch(
data_source="catalog.schema.prompts",
agent_config=config,
cost_component_name="my_team",
)
AgentUDTFConfig#
- class ml_toolkit.functions.llm.udtf.config.AgentUDTFConfig[source]#
Configuration for an OpenAI Agents SDK agent UDTF.
Extends UDTFConfig with agent-specific parameters: model, instructions, tools, and LiteLLM routing.

- Parameters:
  - name – Unique name for this agent (used in UDTF registration and telemetry).
  - instructions – System instructions for the agent. Supports Jinja2 templates with variables from the input row (e.g., {{database}}).
  - model – LiteLLM model identifier (e.g., "openai/gpt-5").
  - tools – List of Python callables to register as agent tools. Each function should have type annotations and a docstring (used to auto-generate the tool schema).
  - include_sql_tool – Whether to include the built-in read-only SQL execution tool.
  - include_web_search_tool – Whether to include the built-in web search tool.
  - max_turns – Maximum number of agent turns (tool call rounds).
  - warehouse_id – Databricks SQL warehouse ID for the SQL tool. Resolved from the DATABRICKS_WAREHOUSE_ID env var or Databricks secrets if None.
  - litellm_base_url – LiteLLM proxy base URL. Resolved from Databricks secrets if None.
  - litellm_api_key – LiteLLM proxy API key. Resolved from Databricks secrets if None.
  - extra_agent_kwargs – Additional kwargs passed to the agents.Agent constructor.
Examples
Basic configuration with custom tools.#

from ml_toolkit.functions.llm.udtf import AgentUDTFConfig

def my_lookup(query: str) -> str:
    """Look up data in a custom API."""
    return "result"

config = AgentUDTFConfig(
    name="my_agent",
    instructions="You are a research assistant.",
    model="openai/gpt-5",
    tools=[my_lookup],  # auto-wrapped with agents.function_tool()
)

Configuration with built-in tools and Jinja2 instructions.#

config = AgentUDTFConfig(
    name="corp_insights",
    instructions="You analyze data for the {{database}} team.",
    model="openai/databricks-claude-3-7-sonnet",
    include_sql_tool=True,
    include_web_search_tool=True,
)
Note
Tools are plain Python callables with type annotations and docstrings. They are automatically wrapped with agents.function_tool() at agent build time. The OpenAI Agents SDK uses the function signature and docstring to generate the tool schema that the LLM sees.
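How a tool schema can be derived from a signature and docstring is easy to illustrate with the standard inspect module; this is a simplified sketch, not the Agents SDK's actual schema format:

```python
import inspect

def my_lookup(query: str, limit: int = 5) -> str:
    """Look up data in a custom API."""
    return "result"

def tool_schema(fn) -> dict:
    """Derive a minimal tool description from a function's signature and docstring."""
    sig = inspect.signature(fn)
    return {
        "name": fn.__name__,
        "description": inspect.getdoc(fn),
        "parameters": {
            name: getattr(p.annotation, "__name__", str(p.annotation))
            for name, p in sig.parameters.items()
        },
    }

schema = tool_schema(my_lookup)
```

This is why the docstring and annotations matter: they are the only information the LLM sees about what the tool does and what arguments it accepts.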
run_agent_batch#
- ml_toolkit.functions.llm.udtf.interface.run_agent_batch(data_source: pyspark.sql.DataFrame | str, agent_config: ml_toolkit.functions.llm.udtf.config.AgentUDTFConfig, prompt_column: str = 'prompt', output_table_name: str | None = None, dry_run: bool = None, table_operation: Literal['overwrite', 'append'] = 'overwrite', cost_component_name: str = None) dict[source]#
Run an agent against every row in a DataFrame or table.
This is the high-level batch API (analogous to run_llm_batch()). It registers a temporary UDTF, runs it against all input rows, and writes results to a Delta table or returns them as a DataFrame.

There are two operation modes:

- dry_run=True: returns results as a DataFrame without writing. Only works with fewer than 1,000 rows.
- dry_run=False: writes results to output_table_name. Can overwrite or append (set via table_operation).
Attention
You must pass a cost_component_name, otherwise this function will raise an exception.

- Parameters:
  - data_source – Input DataFrame or fully-qualified Delta table name.
  - agent_config – AgentUDTFConfig defining the agent.
  - prompt_column – Column name containing the user prompt. Defaults to "prompt".
  - output_table_name – Fully-qualified output Delta table name (required for batch mode).
  - dry_run – If True, returns results as a DataFrame. If None, auto-detected based on row count.
  - table_operation – "overwrite" or "append" for the output table.
  - cost_component_name – Required cost component for telemetry.
- Returns: Dict with source_uuid, model, df, output_table_name, cost_component_name, and error_message.
- Raises:
  - ValueError – If cost_component_name is missing or output_table_name is missing in batch mode.
  - MLOpsToolkitTooManyRowsForInteractiveUsage – If dry_run=True with too many rows.
Examples
from ml_toolkit.functions.llm import AgentConfig, run_agent_batch

config = AgentConfig(
    name="corp_insights",
    instructions="You analyze corporate data for the {{database}} team.",
    model="openai/databricks-claude-3-7-sonnet",
    include_sql_tool=True,
    include_web_search_tool=True,
)

result = run_agent_batch(
    data_source="catalog.schema.input_prompts",
    agent_config=config,
    prompt_column="prompt",
    output_table_name="catalog.schema.agent_output",
    cost_component_name="my_team",
)
display(result["df"])
Multi-Model Inference UDTF#
Send each input row to multiple models in parallel and get one output row per (input, model). Useful for model comparison, consensus checks, and A/B testing.
from ml_toolkit.functions.llm.udtf import MultiModelUDTFConfig, setup_udtf
config = MultiModelUDTFConfig(max_workers=20, max_retries=3)
setup_udtf(config)
# Input must have: prompt, models (comma-separated), optional system_prompt
results = spark.sql("""
SELECT * FROM multi_model_inference(
TABLE(
SELECT
prompt,
'openai/gpt-4o,openai/databricks-claude-3-7-sonnet' AS models
FROM my_prompts
)
)
""")
The UDTF yields one row per (input, model) with the following output columns, all of type string:

- The model that produced this output
- Cleaned response (markdown code fences stripped)
- JSON of all non-prompt/models columns from the input row
- Error details if the call failed (empty on success)
- Unprocessed model response
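Conceptually, the UDTF fans each prompt out to every listed model in parallel; a self-contained sketch with a stubbed model call (the real implementation goes through the LiteLLM proxy):

```python
from concurrent.futures import ThreadPoolExecutor

def call_model(model: str, prompt: str) -> str:
    """Stub for a chat-completion call (real code hits the LiteLLM proxy)."""
    return f"{model} says: {prompt.upper()}"

def fan_out(prompt: str, models: str, max_workers: int = 8) -> list[tuple[str, str]]:
    """Return one (model, response) pair per model in the comma-separated list."""
    names = [m.strip() for m in models.split(",")]
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        responses = pool.map(lambda m: call_model(m, prompt), names)
        return list(zip(names, responses))

rows = fan_out("hello", "gpt-4o, claude-3")
```

Each returned pair corresponds to one output row, which is what makes side-by-side model comparison in SQL straightforward.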
MultiModelUDTFConfig#
- class ml_toolkit.functions.llm.udtf.config.MultiModelUDTFConfig[source]#
Configuration for a multi-model inference UDTF.
Extends UDTFConfig with LiteLLM credentials and an optional structured-output schema. The UDTF sends each input row to multiple models in parallel and yields one output row per (input, model).

- Parameters:
  - litellm_base_url – LiteLLM proxy base URL. Resolved from the LITELLM_PROXY_BASE_URL env var or Databricks secrets if None.
  - litellm_api_key – LiteLLM proxy API key. Resolved from the LITELLM_API_KEY env var or Databricks secrets if None.
  - response_schema – Optional JSON-schema dict for structured output. When set, passed as response_format to the chat completions API.
Examples
from ml_toolkit.functions.llm.udtf import (
    MultiModelUDTFConfig,
    setup_multi_model_udtf,
)

config = MultiModelUDTFConfig(max_workers=20, max_retries=3)
setup_multi_model_udtf(config)

# Input must have: prompt, models (comma-separated), optional system_prompt
spark.sql("""
    SELECT * FROM multi_model_inference(
        TABLE(SELECT prompt, 'gpt-4o,claude-3' AS models FROM my_table)
    )
""")
Inference UDTF#
Run a single-model chat completion against every row. Supports <<column>>
prompt templates, system prompts, and any LiteLLM-compatible model.
from ml_toolkit.functions.llm.udtf import InferenceConfig, setup_udtf
config = InferenceConfig(
model="openai/gpt-5",
prompt_template="Tag the following vendor name: <<input>>. Options: <<candidates>>.",
system_prompt="You are a precise tagging assistant.",
input_column="input",
additional_context_columns=["candidates"],
max_output_tokens=256,
cost_component_name="my_team",
)
setup_udtf(config, name="vendor_tagger")
results = spark.sql("""
SELECT * FROM vendor_tagger(
TABLE(SELECT * FROM catalog.schema.eval_data)
)
""")
InferenceConfig#
- class ml_toolkit.functions.llm.udtf.config.InferenceConfig[source]#
Configuration for inference UDTF.
Extends UDTFConfig with model, prompt, and LiteLLM credentials for running chat completion inference.

- Parameters:
  - model – LiteLLM model identifier (e.g., "gpt-5").
  - prompt_template – Jinja2 template using <<variable>> syntax. If None, the raw input column value is used as the prompt.
  - system_prompt – Optional system prompt for the chat completion.
  - max_output_tokens – Maximum output tokens for the model response.
  - input_column – Column name containing the input text.
  - additional_context_columns – Additional columns available in the template context.
  - litellm_base_url – LiteLLM proxy base URL. Resolved from Databricks secrets if None.
  - litellm_api_key – LiteLLM proxy API key. Resolved from Databricks secrets if None.
  - litellm_model_kwargs – Additional kwargs for the chat completions API.
Serving Endpoint UDTF#
Query a specific served model version in a multi-version serving endpoint using
the Databricks REST API. This is the low-level UDTF behind
query_serving_endpoint() (see Serving Endpoints).
The endpoint name and served model can be set in the config or provided per-row, enabling dynamic routing across endpoints and versions.
from ml_toolkit.functions.llm.udtf import ServingEndpointConfig, setup_serving_endpoint_udtf
config = ServingEndpointConfig(
endpoint_name="my-endpoint",
served_model_name="v1",
)
setup_serving_endpoint_udtf(config, name="query_my_endpoint")
results = spark.sql("""
SELECT * FROM query_my_endpoint(
TABLE(SELECT prompt FROM catalog.schema.prompts)
)
""")
# DataFrame provides endpoint_name and served_model_name per row
config = ServingEndpointConfig() # no defaults needed
setup_serving_endpoint_udtf(config)
results = spark.sql("""
SELECT * FROM udtf_serving_endpoint(
TABLE(
SELECT
prompt,
endpoint_name,
served_model_name,
512 AS max_tokens,
0.7 AS temperature
FROM catalog.schema.routed_prompts
)
)
""")
ServingEndpointConfig#
- class ml_toolkit.functions.llm.udtf.config.ServingEndpointConfig[source]#
Configuration for serving endpoint UDTF.
Extends UDTFConfig for querying specific served models in multi-version serving endpoints using the Databricks REST API.
- Parameters:
endpoint_name – Name of the serving endpoint (optional, can be provided per-row).
served_model_name – Served model name to route to (optional, can be provided per-row). Format: "{name}-{endpoint_name}" (e.g., "v1-my-endpoint").
timeout_sec – Timeout in seconds for serving endpoint requests. Default is 900 seconds (15 minutes) to accommodate cold starts from scale-to-zero.
Examples
Configuration with default endpoint routing.#
from ml_toolkit.functions.llm.udtf import ServingEndpointConfig

config = ServingEndpointConfig(
    endpoint_name="my-endpoint",
    served_model_name="v1-my-endpoint",
)
Note
Both endpoint_name and served_model_name can be provided per-row in the DataFrame, allowing dynamic routing across multiple endpoints and model versions.
setup_serving_endpoint_udtf#
- ml_toolkit.functions.llm.udtf.serving_endpoint.setup_serving_endpoint_udtf(config: ml_toolkit.functions.llm.udtf.config.ServingEndpointConfig, name: str = 'udtf_serving_endpoint', prompt_column: str = 'prompt')[source]#
Set up and register the serving endpoint UDTF.
Uses the generic setup_udtf from base.py for consistent behavior.
- Parameters:
config – ServingEndpointConfig instance
name – UDTF name for SQL registration
prompt_column – Column name containing the user prompt
- Returns:
The registered UDTF class
Functions#
run_llm_batch#
- ml_toolkit.functions.llm.run_llm_batch(data_source: pyspark.sql.DataFrame | str, prompt_source: str = None, output_column_name: str = 'llm_output', output_table_name: str | None = None, output_structured_schema: dict | None = None, model: str = DEFAULT_MODEL_PROCESS, dry_run: bool = None, max_output_tokens: int = DEFAULT_MAX_OUTPUT_TOKENS, wait_for_completion: bool = True, table_operation: Literal['overwrite', 'append'] = 'overwrite', primary_key_columns: List[str] | None = None, cost_component_name: str = None, temperature: float = 0.7) dict[source]#
Performs a row-level operation on the input data (Spark DataFrame or table) by passing the prompt_source (along with the row data referenced in it) to an LLM model and writes the output to output_column_name.
There are two operation modes:
dry_run=True: good for quick experimentation, POCs, and prompt engineering; only works with fewer than 1k rows of data.
dry_run=False: should be used for the production pipeline or when running over a lot of data. Writes data to output_table_name and can either overwrite or append (set via table_operation).
Caution
dry_run=True does incur usage and costs. It's a mode designed to allow faster and cheaper experimentation. If all you want is to estimate usage on a bigger dataset, please run estimate_token_usage.
Attention
We highly recommend using the default model or, if something different is needed, going for the Databricks Llama models. They offer the best token throughput and cost. Also, output_structured_schema is only available for Llama models. This means you can provide a Python dict with the desired schema of the LLM output and have that work out of the box.
Attention
You must pass a cost_component_name, otherwise this function will raise an exception.
Prompt Building#
Express clearly what the model's goal is and how it should approach its task. Don't be overly wordy: the prompt is wrapped around every row, so token count grows fast. You reference your data (the columns of your DataFrame) with the <<col_name>> syntax. There's an example below, but you can see more in the examples section.
Warning
Do not use single quotes ('), because they break our prompt formatting.
Parameters:#
- param data_source:
DataFrame or Delta table name to run_llm_batch.
- param prompt_source:
String of the prompt.
- param output_column_name:
Name of the column to write the LLM output.
- param output_table_name:
Optional Delta table to write results to.
- param output_structured_schema:
Optional structured output dict (only available for llama models).
- param model:
Name of the LLM model to use.
- param max_output_tokens:
Maximum number of tokens the LLM can output.
- param dry_run:
Whether to run the processing job locally or trigger a remote batch run.
- param wait_for_completion:
Whether to wait for job completion (only applies in batch mode).
- param table_operation:
Operation to perform on the output table.
- param primary_key_columns:
Primary key columns for the output table.
- param cost_component_name:
Name of the cost component.
- param temperature:
Temperature parameter for the model (default 0.7).
- returns:
Result of the processing job.
- raises ValueError:
If output_table_name is not provided in batch mode or cost_component_name is missing.
- raises MLOpsToolkitTooManyRowsForInteractiveUsage:
If dry_run=True with too many rows.
Examples
Parsing a column to translate its content.#
from ml_toolkit.functions.llm import run_llm_batch

prompt = "You are an AI translator. Please translate the following text into english: <<text_col>>"

run_llm_batch(
    data_source="catalog.schema.input_table",
    prompt_source=prompt,
    output_table_name="catalog.schema.output_table",
    output_column_name="text_col_en",
    table_operation="overwrite",
    cost_component_name=...,  # use your team's cost component here!
)
Using output_structured_schema#
from ml_toolkit.functions.llm import run_llm_batch

output_schema = {
    "name": "Error evaluation",
    "schema": {
        "type": "object",
        "properties": {
            "is_human_error": {"type": "boolean"},
            "confidence": {"type": "integer", "minimum": 0, "maximum": 10},
        },
    },
}

prompt = """
You are an expert python engineer. Your job is to look through error messages and output
the source of the error and if the error looks like it came from a human error or not.
Here is the record:
Error: <<error>>
"""

res = run_llm_batch(
    data_source=df,
    prompt_source=prompt,
    max_output_tokens=64,
    cost_component_name=...,  # use your team's cost component here!
    output_structured_schema=output_schema,
)
df_llm = res["df"]
display(df_llm)
estimate_token_usage#
- ml_toolkit.functions.llm.estimate_token_usage(data_source: str | pyspark.sql.dataframe.DataFrame, prompt_source: str = None, max_output_tokens: int = DEFAULT_MAX_OUTPUT_TOKENS, model: str = DEFAULT_MODEL_PROCESS) Tuple[int, int, int][source]#
Estimates the token usage of a given run of run_llm_batch (data + prompt). Quotas are applied at the net-token level, which is our internal metric to account for the fact that output tokens are 4x more expensive than input tokens.
Tip
Reduce max_output_tokens in order to greatly decrease cost.
data_source – DataFrame or Delta table name to run_llm_batch.
prompt_source – String of the prompt.
max_output_tokens – Maximum number of tokens the LLM can output.
model – Name of the LLM model to use.
- Returns total_input_tokens_estimate:
Estimate of total input token usage.
- Returns total_output_tokens_estimate:
Estimate of total output token usage.
- Returns total_net_tokens_estimate:
Estimate of total net token usage.
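Since net tokens are the quota metric, it helps to see the arithmetic. This is a sketch based on the 4x cost factor stated above — the toolkit's exact internal formula may differ:

```python
def net_tokens(input_tokens: int, output_tokens: int) -> int:
    # Output tokens weighted 4x, per the cost factor described above.
    # Sketch only: the toolkit's actual net-token formula may differ.
    return input_tokens + 4 * output_tokens

# 10k rows averaging 300 input tokens each, capped at 64 output tokens:
print(net_tokens(10_000 * 300, 10_000 * 64))
# 5560000
```

This is also why lowering max_output_tokens is the single biggest lever on cost: each output token you trim saves four net tokens.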
query#
- ml_toolkit.functions.llm.query(prompt: str, inputs: str | dict | pyspark.sql.dataframe.DataFrame | None = None, context_files: str | List[str] | None = None, model: str = DEFAULT_MODEL_QUERY, response_format: Literal['text', 'json'] = 'text', tools: list[Callable | dict] | None = None, max_tool_calls: int = None, max_output_tokens: int | None = 2048, max_tokens_param_name: str | None = 'max_tokens', system_prompt: str | None = None, cost_component_name: str | None = None, model_extra_options: dict = {}) LLMResponse[source]#
Query an LLM model and get its response.
This function allows you to prompt an LLM and provides several conveniences to streamline usage, such as input data (including DataFrames), files, tools, and web search. Check the examples to see advanced usage of these features.
- Parameters:
prompt – The main prompt text to send to the LLM model.
inputs – Optional input data that can be a string, dictionary, or DataFrame to provide context.
context_files – Optional file path(s) to include as context. Can be a single string or list of strings.
model – Optional name of the LLM model to use.
response_format – Format of the response - either “text” or “json”. Defaults to “text”.
tools – Optional list of functions or dicts that the LLM can use as tools.
max_tool_calls – Maximum number of tool calls allowed.
max_output_tokens – Maximum number of tokens in the response. Defaults to 2048.
max_tokens_param_name – Name of the argument to pass to the model for the maximum number of tokens (max_tokens/max_completion_tokens/max_output_tokens). Defaults to "max_tokens".
system_prompt – Optional system prompt to set the behavior of the model.
cost_component_name – Name for tracking costs.
model_extra_options – Additional model-specific options as a dictionary.
- Returns:
LLMResponse object containing the model's response.
- Raises:
ValueError – If there are errors reading context files.
Examples
Querying a search-enabled model.#
from ml_toolkit.functions.llm import query

resp = query(
    "What are the latest info on the Trump Tariffs?",
    model="gpt-4o-mini-search-preview",
    model_extra_options={
        "web_search_options": {
            "user_location": {
                "type": "approximate",
                "approximate": {
                    "country": "US",
                },
            },
        },
    },
    cost_component_name=...,  # use your team's cost component here!
)
print(resp.text)
Passing a dataframe as context and asking the model questions about our data.#
from ml_toolkit.functions.llm import query

df = spark.table(...).limit(10)

response = query(
    "How can I query this table to know the usage rates by hour over the last 21 days?",
    inputs=df,
    response_format="text",
    cost_component_name=...,  # use your team's cost component here!
)
print(response.text)
Passing functions to increase the model's capabilities.#
from ml_toolkit.functions.llm import query

def sum_two_numbers(a: float, b: float) -> float:
    """Sums two numbers and returns the result."""
    return a + b

response = query(
    "What's the sum between 1.1234 and 4.4321?",
    tools=[sum_two_numbers],
    response_format="json",
    cost_component_name=...,  # use your team's cost component here!
)
print(response.json())
LLMResponse#
- class ml_toolkit.functions.llm.query.LLMResponse[source]#
Class that defines the LLM response of the query function. It exposes the following attributes:
text: returns the text response
response: returns the raw LLM response object (openai.ChatCompletion)
And also the following methods:
.json(): tries to parse the output into a Python dict if possible
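The ".json(): tries to parse ... if possible" behavior can be pictured as a best-effort parse with a graceful fallback. This standalone sketch mirrors that idea; it is not LLMResponse's actual implementation:

```python
import json

def parse_json_or_none(text: str):
    """Best-effort parse of an LLM text response into a dict.

    Illustrative sketch of the kind of fallback .json() performs;
    not the actual LLMResponse implementation.
    """
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        return None

print(parse_json_or_none('{"sum": 5.5555}'))  # {'sum': 5.5555}
print(parse_json_or_none("not json"))         # None
```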
run_vector_search#
- ml_toolkit.functions.llm.run_vector_search(data_source: pyspark.sql.DataFrame | str, index_name: str, search_column: str, output_columns: List[str] | None = None, output_table_name: str | None = None, num_results: int = 10, dry_run: bool = None, wait_for_completion: bool = True, table_operation: Literal['overwrite', 'append'] = 'overwrite', primary_key_columns: List[str] | None = None, cost_component_name: str = None, query_type: Literal['nearest', 'hybrid'] = 'nearest') dict[source]#
Performs vector search on the input data (Spark DataFrame or table) using a specified search index. The search is performed on the column specified by search_column and results are written to output_table_name.
The function supports two query types:
nearest: pure vector similarity search (default); best for semantic similarity matching.
hybrid: combined vector and keyword search; useful when you want both semantic and exact keyword matching.
There are two operation modes:
dry_run=True: good for quick experimentation and testing; only works with fewer than 1k rows of data.
dry_run=False: should be used for the production pipeline or when running over a lot of data. Writes data to output_table_name and can either overwrite or append (set via table_operation).
Caution
dry_run=True does incur usage and costs. It's a mode designed to allow faster and cheaper experimentation.
Attention
You must pass a cost_component_name, otherwise this function will raise an exception.
Parameters:#
- param data_source:
DataFrame or Delta table name to run vector search on.
- param index_name:
Fully qualified name of the vector search index to use.
- param search_column:
Name of the column containing text to search.
- param output_columns:
Optional list of column names from the base DataFrame to return. If None, returns all columns.
- param output_table_name:
Optional Delta table to write results to.
- param num_results:
Number of results to return per search (default: 10).
- param dry_run:
Whether to run the processing job locally or trigger a remote batch run.
- param wait_for_completion:
Whether to wait for job completion (only applies in batch mode).
- param table_operation:
Operation to perform on the output table.
- param primary_key_columns:
Primary key columns for the output table.
- param cost_component_name:
Name of the cost component.
- param query_type:
Type of query to run. Must be either ‘nearest’ (default) for approximate nearest neighbor or ‘hybrid’ for combined vector and keyword search.
- returns:
Result of the processing job.
- raises ValueError:
If output_table_name is not provided in batch mode, cost_component_name is missing, or query_type is invalid.
- raises MLOpsToolkitTooManyRowsForInteractiveUsage:
If dry_run=True with too many rows.
Examples
Running vector search on a table.#
from ml_toolkit.functions.llm import run_vector_search

run_vector_search(
    data_source="catalog.schema.input_table",
    index_name="catalog.schema.search_index",
    search_column="text_to_search",
    output_columns=["id", "name", "description"],
    output_table_name="catalog.schema.output_table",
    table_operation="overwrite",
    query_type="nearest",  # Use pure vector similarity search
    cost_component_name=...,  # use your team's cost component here!
)
Running vector search interactively on a DataFrame.#
from ml_toolkit.functions.llm import run_vector_search

df = spark.createDataFrame([
    ("apple iphone 13",),
    ("samsung galaxy s21",),
], ["product_name"])

res = run_vector_search(
    data_source=df,
    index_name="catalog.schema.products_index",
    search_column="product_name",
    output_columns=["product_name", "category"],
    num_results=5,
    query_type="hybrid",  # Use combined vector and keyword search
    cost_component_name=...,  # use your team's cost component here!
)
df_results = res["df"]
display(df_results)
Using different query types for different use cases.#
# For semantic similarity search (recommended for most use cases)
run_vector_search(
    data_source="catalog.schema.products",
    index_name="catalog.schema.product_embeddings",
    search_column="product_description",
    query_type="nearest",  # Pure vector similarity
    output_table_name="catalog.schema.similar_products",
    cost_component_name=...,
)

# For search that combines semantic and keyword matching
run_vector_search(
    data_source="catalog.schema.queries",
    index_name="catalog.schema.document_embeddings",
    search_column="search_query",
    query_type="hybrid",  # Vector + keyword search
    output_table_name="catalog.schema.search_results",
    cost_component_name=...,
)
Serving Endpoints#
The serving_endpoints submodule provides functions for deploying and managing
Databricks serving endpoints for Unity Catalog LLM models. It supports single-version
and multi-version deployments with automatic traffic management using predefined
LLM serving standards.
LLM Serving Standards — five predefined tiers for endpoint configuration:

| Standard | Min tokens/sec | Max tokens/sec | Use case |
|---|---|---|---|
| XSMALL | 0 | 1,000 | Dev / testing (scale-to-zero) |
| SMALL | 1,000 | 5,000 | Light production (default) |
| MEDIUM | 5,000 | 10,000 | Standard production |
| LARGE | 10,000 | 30,000 | High throughput |
| XLARGE | 30,000 | 70,000 | Maximum throughput |
from ml_toolkit.functions.llm import (
deploy_model_serving_endpoint,
add_model_version_to_endpoint,
update_endpoint_traffic,
query_serving_endpoint,
)
# 1. Deploy a new endpoint (name inferred as "tiny_llama")
endpoint = deploy_model_serving_endpoint(
model_name="catalog.schema.tiny_llama",
model_versions=[1],
standard="XSMALL",
cost_component_name="ml_research",
wait_for_ready=True,
)
# 2. Add a second version with equal traffic
endpoint = add_model_version_to_endpoint(
version=2,
model_name="catalog.schema.tiny_llama",
standard="XSMALL",
)
# 3. Shift traffic to the new version
update_endpoint_traffic(
model_name="catalog.schema.tiny_llama",
traffic_config={1: 20, 2: 80},
)
# 4. Query a specific version
result_df = query_serving_endpoint(
df=spark.table("prompts"),
endpoint_name="tiny_llama",
version=2,
prompt_column="text",
)
deploy_model_serving_endpoint#
- ml_toolkit.functions.llm.serving_endpoints.interface.deploy_model_serving_endpoint(model_name: str, model_versions: List[int], endpoint_name: str | None = None, endpoint_scaling_size: ml_toolkit.functions.llm.serving_endpoints.constants.EndpointScalingSize = 'SMALL', provisioning_type: ml_toolkit.functions.llm.serving_endpoints.constants.ProvisioningType | None = None, traffic_config: Dict[int, int] | None = None, cost_component_name: str | None = None, tags: Dict[str, str] | None = None, scale_to_zero_enabled: bool = True, wait_for_ready: bool = DEFAULT_WAIT_FOR_READY, timeout: int = DEFAULT_TIMEOUT) Dict[source]#
Deploy one or more versions of a Unity Catalog model to a NEW serving endpoint.
This function creates a new serving endpoint. If the endpoint already exists with any served models, it will raise MLOpsToolkitEndpointAlreadyExistsException. Use add_model_version_to_endpoint() to add versions to an existing endpoint.
Automatically determines whether to use provisioned throughput or GPU config by calling the Databricks get-model-optimization-info API. For provisioned throughput models, token values are rounded to the model's chunk_size increment. For GPU models, the endpoint_scaling_size maps to a workload_type + workload_size combo.
- Parameters:
model_name – Unity Catalog model path (catalog.schema.model). Required. Example: "yd_ml_dpe.fine_tune_staging.qlora_tinyllama_model".
model_versions – List of Unity Catalog model version numbers (integers) to deploy. Each entry corresponds to a version on the endpoint. Required. Example: [1] for a single version, [1, 2] for two versions. The served model names will be: v1, v2, etc.
endpoint_name – Optional name of the serving endpoint to create. If not provided, the endpoint name will be automatically inferred from the model name (the last part after the final dot). Must not already exist with served models, or MLOpsToolkitEndpointAlreadyExistsException will be raised. Example: if model_name is "catalog.schema.my_model", endpoint_name defaults to "my_model".
endpoint_scaling_size – Endpoint scaling size. Defaults to "SMALL". Must be one of: "XSMALL", "SMALL", "MEDIUM", "LARGE", "XLARGE". For provisioned throughput models, this maps to token-per-second ranges. For GPU models, this maps to workload_type + workload_size combos: XSMALL=GPU_SMALL/Small, SMALL=GPU_MEDIUM/Small, MEDIUM=GPU_MEDIUM/Medium, LARGE=MULTIGPU_MEDIUM/Medium, XLARGE=GPU_MEDIUM_8/Large.
provisioning_type – Optional override for provisioning type. If not provided, the type is auto-detected by calling the Databricks optimization-info API. Must be "PROVISIONED_THROUGHPUT" or "GPU".
traffic_config – Optional traffic distribution dict mapping version numbers to traffic percentages. If None, traffic is distributed equally. Example: {1: 70, 2: 30} means version 1 gets 70%, version 2 gets 30%.
cost_component_name – Required cost component for tracking. If not provided, will attempt to determine from settings or environment, but will raise ValueError if not found. Always specify this explicitly for production deployments.
tags – Optional additional tags as key-value pairs.
scale_to_zero_enabled – Whether to enable scale-to-zero for the endpoint (default: True). When True, the endpoint scales to zero when idle. When False, the endpoint stays always warm.
wait_for_ready – Whether to wait for the endpoint to be ready before returning.
timeout – Maximum seconds to wait if wait_for_ready is True.
- Return type:
Endpoint details dictionary
Examples
Deploy provisioned throughput model (auto-detected)#
from ml_toolkit.functions.llm.serving_endpoints import deploy_model_serving_endpoint

endpoint = deploy_model_serving_endpoint(
    model_name="catalog.schema.tiny_llama",
    model_versions=[1],
    endpoint_scaling_size="XSMALL",
    cost_component_name="ml_research",
    wait_for_ready=True,
)

Deploy GPU model with explicit provisioning type#
endpoint = deploy_model_serving_endpoint(
    model_name="catalog.schema.custom_model",
    model_versions=[1],
    endpoint_scaling_size="MEDIUM",
    provisioning_type="GPU",
    cost_component_name="data_integration_rd",
)

Deploy multiple versions with A/B testing#
endpoint = deploy_model_serving_endpoint(
    model_name="catalog.schema.llama_model",
    model_versions=[1, 2],
    traffic_config={1: 50, 2: 50},
    endpoint_scaling_size="LARGE",
    cost_component_name="ml_research",
)
add_model_version_to_endpoint#
- ml_toolkit.functions.llm.serving_endpoints.interface.add_model_version_to_endpoint(version: int, model_name: str, endpoint_name: str | None = None, endpoint_scaling_size: ml_toolkit.functions.llm.serving_endpoints.constants.EndpointScalingSize = 'SMALL', provisioning_type: ml_toolkit.functions.llm.serving_endpoints.constants.ProvisioningType | None = None, traffic_percentage: int | None = None, redistribute_traffic: bool = True, scale_to_zero_enabled: bool = True, wait_for_ready: bool = DEFAULT_WAIT_FOR_READY, timeout: int = DEFAULT_TIMEOUT) Dict[source]#
Add a new model version to an existing serving endpoint.
This function fetches the current endpoint configuration, auto-detects the model name and version from existing versions, and adds the new version with automatic traffic distribution.
Automatically determines whether to use provisioned throughput or GPU config. Validates that the new entity’s provisioning type matches the existing endpoint entities.
- Parameters:
version – Integer version identifier (e.g., 1, 2, 3). The served model name will be v{version} (e.g., v1, v2, v3).
model_name – Unity Catalog model path (catalog.schema.model). Required. Used to infer the endpoint name if endpoint_name is not explicitly provided.
endpoint_name – Optional name of the existing serving endpoint. If not provided, the endpoint name will be automatically inferred from model_name.
endpoint_scaling_size – Endpoint scaling size. Defaults to "SMALL". Must be one of: "XSMALL", "SMALL", "MEDIUM", "LARGE", "XLARGE". For provisioned throughput models, this maps to token-per-second ranges. For GPU models, this maps to workload_type + workload_size combos.
provisioning_type – Optional override for provisioning type. If not provided, the type is auto-detected by calling the Databricks optimization-info API. Must be "PROVISIONED_THROUGHPUT" or "GPU".
traffic_percentage – Percentage of traffic to route to this new version (0-100). If None and redistribute_traffic=True, traffic is distributed equally. If specified, other versions' traffic will be proportionally reduced.
redistribute_traffic – If True, automatically adjust traffic across all versions. If False, the new version gets 0% traffic (you must manually call update_endpoint_traffic).
scale_to_zero_enabled – Whether to enable scale-to-zero for the new version (default: True). When True, the endpoint scales to zero when idle. When False, the endpoint stays always warm.
wait_for_ready – Whether to wait for the endpoint to be ready before returning.
timeout – Maximum seconds to wait if wait_for_ready is True.
- Return type:
Updated endpoint details dictionary
Examples
Add version using model_name (recommended)#
from ml_toolkit.functions.llm.serving_endpoints import add_model_version_to_endpoint

endpoint = add_model_version_to_endpoint(
    version=3,
    model_name="catalog.schema.my_model",
    endpoint_scaling_size="MEDIUM",
)

Add version with explicit provisioning type#
endpoint = add_model_version_to_endpoint(
    version=2,
    endpoint_name="production_endpoint",
    endpoint_scaling_size="LARGE",
    provisioning_type="GPU",
    traffic_percentage=10,
)
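To see what "other versions' traffic will be proportionally reduced" means in practice, here is an illustrative pure-Python sketch of the arithmetic. The `redistribute` helper is hypothetical, not the toolkit's actual algorithm:

```python
def redistribute(traffic: dict, new_version: int, pct: int) -> dict:
    """Give pct% to the new version and shrink existing versions
    proportionally so the total stays at 100.

    Sketch of the behavior described above; not the toolkit's code.
    """
    scale = (100 - pct) / 100
    updated = {v: round(p * scale) for v, p in traffic.items()}
    # Assign the remainder to the new version so rounding can't break the 100% sum.
    updated[new_version] = 100 - sum(updated.values())
    return updated

print(redistribute({1: 70, 2: 30}, new_version=3, pct=10))
# {1: 63, 2: 27, 3: 10}
```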
update_endpoint_traffic#
- ml_toolkit.functions.llm.serving_endpoints.interface.update_endpoint_traffic(traffic_config: Dict[int, int], model_name: str, endpoint_name: str | None = None, wait_for_ready: bool = DEFAULT_WAIT_FOR_READY, timeout: int = DEFAULT_TIMEOUT) Dict[source]#
Update the traffic distribution for an existing multi-version endpoint.
This is useful for gradual rollouts, canary deployments, or A/B testing scenarios where you want to adjust traffic percentages without redeploying models.
- Parameters:
traffic_config – Dict mapping version numbers to traffic percentages. Must sum to 100. Example: {1: 30, 2: 70}.
model_name – Unity Catalog model path (catalog.schema.model). Required. Used to infer the endpoint name if endpoint_name is not explicitly provided.
endpoint_name – Optional name of the serving endpoint. If not provided, the endpoint name will be automatically inferred from model_name.
wait_for_ready – Whether to wait for the endpoint to be ready after the update.
timeout – Maximum seconds to wait if wait_for_ready is True.
- Return type:
Updated endpoint details dictionary
Examples
Gradually increase traffic to new version using model_name#
# Start: 90% version 1, 10% version 2
update_endpoint_traffic(
    model_name="catalog.schema.my_model",
    traffic_config={1: 90, 2: 10},
)

# After monitoring: 50% split
update_endpoint_traffic(
    model_name="catalog.schema.my_model",
    traffic_config={1: 50, 2: 50},
)

# Final: 100% version 2
update_endpoint_traffic(
    model_name="catalog.schema.my_model",
    traffic_config={1: 0, 2: 100},
)

Using explicit endpoint name#
update_endpoint_traffic(
    endpoint_name="my-endpoint",
    traffic_config={1: 50, 2: 50},
)
get_serving_endpoint#
- ml_toolkit.functions.llm.serving_endpoints.interface.get_serving_endpoint(model_name: str, endpoint_name: str | None = None) Dict[source]#
Get details of a serving endpoint.
- Parameters:
model_name – Unity Catalog model name (catalog.schema.model_name). Required; used to infer endpoint_name if not provided.
endpoint_name – Name of the endpoint. Optional; if not provided, automatically inferred from model_name.
- Return type:
Endpoint details dictionary containing configuration, state, and metadata
Examples
from ml_toolkit.functions.llm.serving_endpoints import get_serving_endpoint

# Endpoint name automatically inferred from model name
endpoint = get_serving_endpoint(model_name="catalog.schema.tiny_llama")
# Endpoint name inferred as "tiny_llama"
print(f"State: {endpoint['state']['ready']}")
print(f"Served entities: {endpoint['config']['served_entities']}")

# Or specify explicit endpoint name
endpoint = get_serving_endpoint(
    model_name="catalog.schema.tiny_llama",
    endpoint_name="my-custom-endpoint",
)
list_serving_endpoints#
- ml_toolkit.functions.llm.serving_endpoints.interface.list_serving_endpoints(filter_tags: Dict[str, str] | None = None) List[Dict][source]#
List all serving endpoints in the workspace.
- Parameters:
filter_tags – Optional dict of tags to filter by. Only endpoints matching ALL specified tags will be returned.
- Return type:
List of endpoint dictionaries
Examples
from ml_toolkit.functions.llm.serving_endpoints import list_serving_endpoints

# List all endpoints
endpoints = list_serving_endpoints()
for ep in endpoints:
    print(f"{ep['name']}: {ep['state']['ready']}")

# Filter by cost component
my_team_endpoints = list_serving_endpoints(
    filter_tags={"cost_component_name": "ml_research"}
)

# Filter by multiple tags
prod_endpoints = list_serving_endpoints(
    filter_tags={"env": "production", "team": "ai-platform"}
)
query_serving_endpoint#
- ml_toolkit.functions.llm.serving_endpoints.helpers.query.query_serving_endpoint(df: pyspark.sql.DataFrame, endpoint_name: str, version: int | None = None, prompt_column: str = 'prompt', prompt_template: str = None, max_tokens: int = 512, temperature: float = 0.7, output_column: str = 'output') pyspark.sql.DataFrame[source]#
Query a specific model version in a multi-version serving endpoint.
This convenience function prepares the DataFrame and calls the UDTF, automatically routing to the specified version (v1, v2, v3, etc.).
- Parameters:
df – Input DataFrame with a prompt column
endpoint_name – Name of the serving endpoint
version – Integer version identifier (e.g., 1, 2, 3). Defaults to None (latest version). Routes to served model “v{version}” (e.g., v1, v2, v3). If None, automatically resolves to the highest available version.
prompt_column – Name of the column containing prompts (default: “prompt”). Ignored if prompt_template is specified.
prompt_template – Optional template string with <<column>> placeholders. Use double angle brackets to reference DataFrame columns. Examples: “Classify: <<text>>”, “Extract entity: <<merchant_name>>”
max_tokens – Maximum output tokens (default: 512)
temperature – Sampling temperature (default: 0.7)
output_column – Name of the output column (default: “output”)
- Returns:
DataFrame with the columns {output_column}, start_timestamp, end_timestamp, and error_message.
- Return type:
DataFrame
Example
>>> from ml_toolkit.functions.llm.serving_endpoints.helpers.query import query_serving_endpoint
>>> from pyspark.sql import SparkSession
>>>
>>> spark = SparkSession.builder.getOrCreate()
>>> df = spark.table("input_table")
>>>
>>> # Query latest version (default)
>>> result_df = query_serving_endpoint(
...     df=df,
...     endpoint_name="my_endpoint",
...     prompt_column="text",
...     max_tokens=1000,
...     temperature=0.5,
... )
>>>
>>> # Query specific version
>>> result_df = query_serving_endpoint(
...     df=df,
...     endpoint_name="my_endpoint",
...     version=1,
...     prompt_template="Classify this <<category>> text: <<content>>",
...     max_tokens=1000,
...     temperature=0.5,
... )