llm module#
The llm module of the ml_toolkit contains all the functions that enable LLM usage within Databricks notebooks.
We expose the following functions:
- `run_fine_tuning`: fine-tune an LLM using QLoRA with Optuna hyperparameter optimization, with local or remote GPU execution.
- `run_llm_batch`: performs row-level LLM querying over your data, writing each response to a new column.
- `estimate_token_usage`: estimates the token usage of the `run_llm_batch` function.
- `query`: directly query a model, optionally giving it context and tools to perform an action.
- `run_vector_search`: performs vector search (nearest or hybrid) over a Databricks vector search index.
- `run_inference`: runs vLLM batch inference on Databricks Serverless GPU clusters using self-hosted Unity Catalog models.
PySpark UDTFs for at-scale inference:
- `setup_udtf`: register any UDTF by passing the appropriate config (agent or multi-model).
- `run_agent_batch`: one-call batch execution of an agent UDTF with table writes.
Prompt management functions for the MLflow Prompt Registry:
- `create_or_fetch_prompt`: create or fetch prompts with automatic deduplication.
- `load_prompt`: load a prompt by name, version, or URI/alias.
- `search_prompts`: search for prompts in the registry.
- `set_prompt_alias`: set aliases (e.g., "production") for prompt versions.
- `delete_prompt_version`: delete a specific version of a prompt.
- `delete_prompt`: delete a prompt from the registry.
Serving endpoint management:
- `deploy_model_serving_endpoint`: deploy Unity Catalog model versions to a new serving endpoint.
- `add_model_version_to_endpoint`: add a model version to an existing endpoint with automatic traffic redistribution.
- `update_endpoint_traffic`: update traffic distribution for A/B testing and canary deployments.
- `get_serving_endpoint`: get endpoint details.
- `list_serving_endpoints`: list all endpoints with optional tag filtering.
- `query_serving_endpoint`: query a specific version of a serving endpoint.
Model registration functions for Unity Catalog:
- `register_LLM_model`: register a HuggingFace LLM (model + tokenizer) to Unity Catalog via `mlflow.transformers`.
- `register_model`: register a generic pyfunc model to Unity Catalog via `mlflow.pyfunc`.
- `deploy_model`: deploy a registered model version to a specified alias (e.g. `"prod"`).
- `log_artifacts`: copy a directory of artifacts to the model version's UC Volume path.
- `log_artifact`: copy a single artifact to the model version's UC Volume path.
Model loading and inspection functions:
- `get_model_info`: fetch metadata for a registered model version (tags, alias, volume path).
- `load_model_with_retry`: load a model from UC or HuggingFace Hub with automatic retry logic.
- `load_auto_tokenizer_with_retry`: load a HuggingFace tokenizer with automatic retry logic.
Attention
These functions are in an experimental phase and are subject to change. If you have any feedback, please submit it via our Jira Form.
Attention
You must pass a `cost_component_name` to the functions that call the LLMs; otherwise they will raise exceptions.
Quota Controls#
We have set strict usage controls, or quotas, to limit usage. This means you can only consume a set number of tokens before an exception is raised. LLM usage can get expensive quickly, so these limits exist to avoid runaway costs.
If you have a valid use case that solves a business problem and needs a higher quota to run with your data, please
submit a ticket. The approval process involves getting your manager to approve your use case and requested budget.
When submitting this, always include the output from `estimate_token_usage`.
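Before filing a quota request, you can sanity-check the order of magnitude with the common ~4 characters-per-token heuristic. This is a rough illustration only (the helper name and constants below are hypothetical, not part of the toolkit); use `estimate_token_usage` for the actual numbers in your ticket:

```python
def rough_token_estimate(texts, chars_per_token=4, expected_output_tokens=256):
    """Back-of-the-envelope token estimate for a batch of prompts.

    Uses the common ~4 characters-per-token heuristic. Illustrative only;
    prefer estimate_token_usage for real quota requests.
    """
    input_tokens = sum(len(t) // chars_per_token + 1 for t in texts)
    output_tokens = expected_output_tokens * len(texts)
    return {
        "input_tokens": input_tokens,
        "output_tokens": output_tokens,
        "total_tokens": input_tokens + output_tokens,
    }

prompts = ["Classify this vendor: ACME Corp", "Classify this vendor: Globex"]
estimate = rough_token_estimate(prompts, expected_output_tokens=50)
```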
Prompts#
The prompt functions provide a complete interface for managing prompts in the MLflow Prompt Registry. Prompts are versioned, support aliases (e.g., “production”, “staging”), and use content-based deduplication.
from ml_toolkit.functions.llm import create_or_fetch_prompt, load_prompt, set_prompt_alias
# Create or fetch a prompt (deduplicates automatically)
result = create_or_fetch_prompt(
name="mycatalog.myschema.qa_prompt",
template="Answer this question: {{question}}",
commit_message="Initial QA prompt"
)
# Set an alias for easy reference
set_prompt_alias("mycatalog.myschema.qa_prompt", alias="production", version=result["version"])
# Load by alias
prompt = load_prompt("prompts:/mycatalog.myschema.qa_prompt@production")
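The deduplication in `create_or_fetch_prompt` is content-based: registering an identical template returns the existing version instead of creating a new one. The idea can be sketched with a template hash (the `_registry` dict and `create_or_fetch` helper below are hypothetical; the real versions live in the MLflow Prompt Registry):

```python
import hashlib

_registry = {}  # name -> list of (version, template_hash, template); illustrative only

def create_or_fetch(name, template, commit_message=""):
    """Return an existing version when the template content is unchanged,
    otherwise register a new version. Sketch of content-based dedup."""
    digest = hashlib.sha256(template.encode("utf-8")).hexdigest()
    for version, h, _ in _registry.get(name, []):
        if h == digest:
            return {"name": name, "version": version, "deduplicated": True}
    versions = _registry.setdefault(name, [])
    version = len(versions) + 1
    versions.append((version, digest, template))
    return {"name": name, "version": version, "deduplicated": False}

first = create_or_fetch("qa_prompt", "Answer this question: {{question}}")
again = create_or_fetch("qa_prompt", "Answer this question: {{question}}")
changed = create_or_fetch("qa_prompt", "Answer briefly: {{question}}")
```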
create_or_fetch_prompt#
load_prompt#
search_prompts#
set_prompt_alias#
delete_prompt_version#
delete_prompt#
UDTFs (User-Defined Table Functions)#
The udtf submodule provides PySpark UDTFs for running LLM inference at scale
inside Spark SQL queries. All UDTFs are registered through a single function,
setup_udtf, which dispatches based on the config type:
- `AgentUDTFConfig`: runs an OpenAI Agents SDK agent (with tools, Jinja2 instructions) per row.
- `MultiModelUDTFConfig`: sends each prompt to multiple models in parallel and yields one row per model.
- `InferenceConfig`: runs a single-model chat completion per row with prompt templating.
- `ServingEndpointConfig`: queries a specific served model in a multi-version serving endpoint via REST API.
setup_udtf#
- ml_toolkit.functions.llm.udtf.interface.setup_udtf(config: ml_toolkit.functions.llm.udtf.config.UDTFConfig, name: str = None, prompt_column: str = 'prompt') type[source]#
Register a PySpark UDTF based on the config type.
Dispatches to the appropriate UDTF implementation based on the config class. After registration, the UDTF can be called from Spark SQL.
Supported config types:
- `AgentUDTFConfig`: runs an OpenAI Agents SDK agent per row.
- `MultiModelUDTFConfig`: sends each prompt to multiple models in parallel, yielding one row per (input, model).
- Parameters:
  - config – A `UDTFConfig` subclass instance.
  - name – UDTF name for SQL registration. Defaults vary by type: `run_{config.name}` for agents, `"multi_model_inference"` for multi-model.
  - prompt_column – Column name containing the user prompt. Defaults to `"prompt"`.
- Returns:
The UDTF class (also registered in SparkSession).
- Raises:
TypeError – If config is not a supported subclass.
Examples
from ml_toolkit.functions.llm.udtf import AgentUDTFConfig, setup_udtf
config = AgentUDTFConfig(
    name="my_agent",
    instructions="You are a helpful assistant.",
    model="openai/databricks-claude-3-7-sonnet",
)
setup_udtf(config)
spark.sql("SELECT * FROM run_my_agent(TABLE(SELECT * FROM prompts))")
from ml_toolkit.functions.llm.udtf import MultiModelUDTFConfig, setup_udtf
config = MultiModelUDTFConfig(max_workers=20)
setup_udtf(config)
spark.sql("""
    SELECT * FROM multi_model_inference(
        TABLE(SELECT prompt, 'gpt-4o,claude-3' AS models FROM prompts)
    )
""")
Agent UDTF#
Register an agent as a PySpark UDTF, then call it from SQL or use the high-level batch API.
from ml_toolkit.functions.llm.udtf import AgentUDTFConfig, setup_udtf, run_agent_batch
config = AgentUDTFConfig(
name="my_agent",
instructions="You are a helpful assistant.",
model="openai/databricks-claude-3-7-sonnet",
tools=[my_tool_function],
)
setup_udtf(config)
results = spark.sql("SELECT * FROM run_my_agent(TABLE(SELECT * FROM prompts))")
# Or use the high-level batch API with table writes:
result = run_agent_batch(
data_source="catalog.schema.prompts",
agent_config=config,
cost_component_name="my_team",
)
AgentUDTFConfig#
- class ml_toolkit.functions.llm.udtf.config.AgentUDTFConfig[source]#
Configuration for an OpenAI Agents SDK agent UDTF.
Extends `UDTFConfig` with agent-specific parameters: model, instructions, tools, and LiteLLM routing.
- Parameters:
  - name – Unique name for this agent (used in UDTF registration and telemetry).
  - instructions – System instructions for the agent. Supports Jinja2 templates with variables from the input row (e.g., `{{database}}`).
  - model – LiteLLM model identifier (e.g., `"openai/gpt-5"`).
  - tools – List of Python callables to register as agent tools. Each function should have type annotations and a docstring (used to auto-generate the tool schema).
  - include_sql_tool – Whether to include the built-in read-only SQL execution tool.
  - include_web_search_tool – Whether to include the built-in web search tool.
  - max_turns – Maximum number of agent turns (tool call rounds).
  - warehouse_id – Databricks SQL warehouse ID for the SQL tool. Resolved from the `DATABRICKS_WAREHOUSE_ID` env var or Databricks secrets if None.
  - litellm_base_url – LiteLLM proxy base URL. Resolved from Databricks secrets if None.
  - litellm_api_key – LiteLLM proxy API key. Resolved from Databricks secrets if None.
  - extra_agent_kwargs – Additional kwargs passed to the `agents.Agent` constructor.
Examples
Basic configuration with custom tools:
from ml_toolkit.functions.llm.udtf import AgentUDTFConfig

def my_lookup(query: str) -> str:
    """Look up data in a custom API."""
    return "result"

config = AgentUDTFConfig(
    name="my_agent",
    instructions="You are a research assistant.",
    model="openai/gpt-5",
    tools=[my_lookup],  # auto-wrapped with agents.function_tool()
)
Configuration with built-in tools and Jinja2 instructions:
config = AgentUDTFConfig(
    name="corp_insights",
    instructions="You analyze data for the {{database}} team.",
    model="openai/databricks-claude-3-7-sonnet",
    include_sql_tool=True,
    include_web_search_tool=True,
)
Note
Tools are plain Python callables with type annotations and docstrings. They are automatically wrapped with `agents.function_tool()` at agent build time. The OpenAI Agents SDK uses the function signature and docstring to generate the tool schema that the LLM sees.
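To make the note concrete, here is a minimal sketch of that introspection step. The `build_tool_schema` helper is hypothetical and not the Agents SDK's actual code; it only illustrates how a signature plus docstring becomes a tool schema:

```python
import inspect

def build_tool_schema(func):
    """Derive a JSON-schema-like tool description from a Python callable,
    mimicking what agents.function_tool() does with annotations + docstring."""
    sig = inspect.signature(func)
    type_map = {str: "string", int: "integer", float: "number", bool: "boolean"}
    properties = {
        name: {"type": type_map.get(param.annotation, "string")}
        for name, param in sig.parameters.items()
    }
    return {
        "name": func.__name__,
        "description": (func.__doc__ or "").strip(),
        "parameters": {
            "type": "object",
            "properties": properties,
            "required": list(properties),
        },
    }

def my_lookup(query: str, limit: int = 5) -> str:
    """Look up data in a custom API."""
    return "result"

schema = build_tool_schema(my_lookup)
```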
run_agent_batch#
- ml_toolkit.functions.llm.udtf.interface.run_agent_batch(data_source: pyspark.sql.DataFrame | str, agent_config: ml_toolkit.functions.llm.udtf.config.AgentUDTFConfig, prompt_column: str = 'prompt', output_table_name: str | None = None, dry_run: bool = None, table_operation: Literal['overwrite', 'append'] = 'overwrite', cost_component_name: str = None) dict[source]#
Run an agent against every row in a DataFrame or table.
This is the high-level batch API (analogous to `run_llm_batch()`). It registers a temporary UDTF, runs it against all input rows, and writes results to a Delta table or returns them as a DataFrame.
There are two operation modes:
- `dry_run=True`: returns results as a DataFrame without writing. Only works with fewer than 1,000 rows.
- `dry_run=False`: writes results to `output_table_name`. Can `overwrite` or `append` (set via `table_operation`).
Attention
You must pass a `cost_component_name`, otherwise this function will raise an exception.
- Parameters:
  - data_source – Input DataFrame or fully-qualified Delta table name.
  - agent_config – `AgentUDTFConfig` defining the agent.
  - prompt_column – Column name containing the user prompt. Defaults to `"prompt"`.
  - output_table_name – Fully-qualified output Delta table name (required for batch mode).
  - dry_run – If True, returns results as a DataFrame. If None, auto-detected based on row count.
  - table_operation – `"overwrite"` or `"append"` for the output table.
  - cost_component_name – Required cost component for telemetry.
- Returns:
  Dict with `source_uuid`, `model`, `df`, `output_table_name`, `cost_component_name`, and `error_message`.
- Raises:
  - ValueError – If `cost_component_name` is missing or `output_table_name` is missing in batch mode.
  - MLOpsToolkitTooManyRowsForInteractiveUsage – If `dry_run=True` with too many rows.
Examples
from ml_toolkit.functions.llm.udtf import AgentUDTFConfig, run_agent_batch

config = AgentUDTFConfig(
    name="corp_insights",
    instructions="You analyze corporate data for the {{database}} team.",
    model="openai/databricks-claude-3-7-sonnet",
    include_sql_tool=True,
    include_web_search_tool=True,
)
result = run_agent_batch(
    data_source="catalog.schema.input_prompts",
    agent_config=config,
    prompt_column="prompt",
    output_table_name="catalog.schema.agent_output",
    cost_component_name="my_team",
)
display(result["df"])
Multi-Model Inference UDTF#
Send each input row to multiple models in parallel and get one output row per (input, model). Useful for model comparison, consensus checks, and A/B testing.
from ml_toolkit.functions.llm.udtf import MultiModelUDTFConfig, setup_udtf
config = MultiModelUDTFConfig(max_workers=20, max_retries=3)
setup_udtf(config)
# Input must have: prompt, models (comma-separated), optional system_prompt
results = spark.sql("""
SELECT * FROM multi_model_inference(
TABLE(
SELECT
prompt,
'openai/gpt-4o,openai/databricks-claude-3-7-sonnet' AS models
FROM my_prompts
)
)
""")
The UDTF yields one row per (input, model) with this schema:
| Column | Type | Description |
|---|---|---|
| | string | The model that produced this output |
| | string | Cleaned response (markdown code fences stripped) |
| | string | JSON of all non-prompt/models columns from the input row |
| | string | Error details if the call failed (empty on success) |
| | string | Unprocessed model response |
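The fan-out and response cleaning described above can be sketched in plain Python (illustrative only; the real UDTF runs inside Spark and calls the models in parallel, and these helper names are not part of the toolkit):

```python
import re

def fan_out(rows):
    """Yield one (prompt, model) pair per model in the comma-separated
    'models' column, mirroring the one-row-per-(input, model) output."""
    for row in rows:
        for model in (m.strip() for m in row["models"].split(",")):
            yield {"prompt": row["prompt"], "model": model}

def strip_code_fences(text):
    """Remove a leading/trailing markdown code fence, as done for the
    cleaned-response column."""
    return re.sub(r"^```[a-zA-Z0-9]*\n|\n?```$", "", text.strip())

rows = [{"prompt": "hello", "models": "gpt-4o, claude-3"}]
pairs = list(fan_out(rows))
cleaned = strip_code_fences('```json\n{"ok": true}\n```')
```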
MultiModelUDTFConfig#
- class ml_toolkit.functions.llm.udtf.config.MultiModelUDTFConfig[source]#
Configuration for a multi-model inference UDTF.
Extends `UDTFConfig` with LiteLLM credentials and an optional structured-output schema. The UDTF sends each input row to multiple models in parallel and yields one output row per (input, model).
- Parameters:
  - litellm_base_url – LiteLLM proxy base URL. Resolved from the `LITELLM_PROXY_BASE_URL` env var or Databricks secrets if None.
  - litellm_api_key – LiteLLM proxy API key. Resolved from the `LITELLM_API_KEY` env var or Databricks secrets if None.
  - response_schema – Optional JSON-schema dict for structured output. When set, passed as `response_format` to the chat completions API.
Examples
from ml_toolkit.functions.llm.udtf import (
    MultiModelUDTFConfig,
    setup_multi_model_udtf,
)
config = MultiModelUDTFConfig(max_workers=20, max_retries=3)
setup_multi_model_udtf(config)
# Input must have: prompt, models (comma-separated), optional system_prompt
spark.sql("""
    SELECT * FROM multi_model_inference(
        TABLE(SELECT prompt, 'gpt-4o,claude-3' AS models FROM my_table)
    )
""")
Inference UDTF#
Run a single-model chat completion against every row. Supports <<column>>
prompt templates, system prompts, and any LiteLLM-compatible model.
from ml_toolkit.functions.llm.udtf import InferenceConfig, setup_udtf
config = InferenceConfig(
model="openai/gpt-5",
prompt_template="Tag the following vendor name: <<input>>. Options: <<candidates>>.",
system_prompt="You are a precise tagging assistant.",
input_column="input",
additional_context_columns=["candidates"],
max_output_tokens=256,
cost_component_name="my_team",
)
setup_udtf(config, name="vendor_tagger")
results = spark.sql("""
SELECT * FROM vendor_tagger(
TABLE(SELECT * FROM catalog.schema.eval_data)
)
""")
InferenceConfig#
- class ml_toolkit.functions.llm.udtf.config.InferenceConfig[source]#
Configuration for inference UDTF.
Extends `UDTFConfig` with model, prompt, and LiteLLM credentials for running chat completion inference.
- Parameters:
  - model – LiteLLM model identifier (e.g., `"gpt-5"`).
  - prompt_template – Jinja2 template using `<<variable>>` syntax. If None, the raw input column value is used as the prompt.
  - system_prompt – Optional system prompt for the chat completion.
  - max_output_tokens – Maximum output tokens for the model response.
  - input_column – Column name containing the input text.
  - additional_context_columns – Additional columns available in the template context.
  - litellm_base_url – LiteLLM proxy base URL. Resolved from Databricks secrets if None.
  - litellm_api_key – LiteLLM proxy API key. Resolved from Databricks secrets if None.
  - litellm_model_kwargs – Additional kwargs for the chat completions API.
Serving Endpoint UDTF#
Query a specific served model version in a multi-version serving endpoint using
the Databricks REST API. This is the low-level UDTF behind
query_serving_endpoint() (see Serving Endpoints).
The endpoint name and served model can be set in the config or provided per-row, enabling dynamic routing across endpoints and versions.
from ml_toolkit.functions.llm.udtf import ServingEndpointConfig, setup_serving_endpoint_udtf
config = ServingEndpointConfig(
endpoint_name="my-endpoint",
served_model_name="v1",
)
setup_serving_endpoint_udtf(config, name="query_my_endpoint")
results = spark.sql("""
SELECT * FROM query_my_endpoint(
TABLE(SELECT prompt FROM catalog.schema.prompts)
)
""")
# DataFrame provides endpoint_name and served_model_name per row
config = ServingEndpointConfig() # no defaults needed
setup_serving_endpoint_udtf(config)
results = spark.sql("""
SELECT * FROM udtf_serving_endpoint(
TABLE(
SELECT
prompt,
endpoint_name,
served_model_name,
512 AS max_tokens,
0.7 AS temperature
FROM catalog.schema.routed_prompts
)
)
""")
ServingEndpointConfig#
- class ml_toolkit.functions.llm.udtf.config.ServingEndpointConfig[source]#
Configuration for serving endpoint UDTF.
Extends `UDTFConfig` for querying specific served models in multi-version serving endpoints using the Databricks REST API.
- Parameters:
  - endpoint_name – Name of the serving endpoint (optional, can be provided per-row).
  - served_model_name – Served model name to route to (optional, can be provided per-row). Format: `"{name}-{endpoint_name}"` (e.g., `"v1-my-endpoint"`).
  - timeout_sec – Timeout in seconds for serving endpoint requests. Default is 900 seconds (15 minutes) to accommodate cold starts from scale-to-zero.
Examples
Configuration with default endpoint routing:
from ml_toolkit.functions.llm.udtf import ServingEndpointConfig
config = ServingEndpointConfig(
    endpoint_name="my-endpoint",
    served_model_name="v1-my-endpoint",
)
Note
Both `endpoint_name` and `served_model_name` can be provided per-row in the DataFrame, allowing dynamic routing across multiple endpoints and model versions.
setup_serving_endpoint_udtf#
- ml_toolkit.functions.llm.udtf.serving_endpoint.setup_serving_endpoint_udtf(config: ml_toolkit.functions.llm.udtf.config.ServingEndpointConfig, name: str = 'udtf_serving_endpoint', prompt_column: str = 'prompt')[source]#
Setup and register the serving endpoint UDTF.
Uses the generic setup_udtf from base.py for consistent behavior.
- Parameters:
config – ServingEndpointConfig instance
name – UDTF name for SQL registration
prompt_column – Column name containing the user prompt
- Returns:
The registered UDTF class
Fine-Tuning#
The run_fine_tuning function provides end-to-end LLM fine-tuning using QLoRA
with Optuna hyperparameter optimization. It loads training data from a Unity Catalog
table, trains a QLoRA adapter, registers the merged model to Unity Catalog, and
returns training metrics. Supports both local GPU execution and remote Databricks
serverless/classic GPU clusters.
from ml_toolkit.functions.llm import run_fine_tuning
REQUIREMENTS = [
"transformers==4.57.6", "datasets==4.5.0", "accelerate==1.12.0",
"peft==0.18.1", "bitsandbytes==0.49.1", "safetensors==0.7.0",
"threadpoolctl==3.6.0", "optuna==4.7.0",
]
result = run_fine_tuning(
training_table_name="catalog.schema.training_data",
model_name="catalog.schema.my_fine_tuned_model",
base_model="qwen2_5_3b_instruct",
requirements=REQUIREMENTS,
)
print(f"Training loss: {result['training_loss']:.4f}")
print(f"Model version: {result['model_version']}")
from ml_toolkit.functions.llm import run_fine_tuning
from ml_toolkit.ml.llm.fine_tuning.config import (
DataConfig, LoRAConfig, TrainingConfig,
)
result = run_fine_tuning(
training_table_name="catalog.schema.vendor_data",
model_name="catalog.schema.vendor_tagger",
base_model="qwen2_5_7b_instruct",
requirements=REQUIREMENTS + ["jinja2"],
adapter_config=LoRAConfig(
r=32, # LoRA rank (default: 16)
alpha=64, # scaling factor (default: 2*r)
dropout=0.1, # dropout rate (default: 0.05)
target_modules="all-linear", # (default: "all-linear")
),
training_config=TrainingConfig(
max_steps=300, # steps per trial (default: 100)
n_trials=10, # Optuna trials (default: 3)
max_seq_length=1024, # max tokens (default: 256)
seed=42, # (default: 9)
eval_steps=10, # (default: 1)
logging_steps=5, # (default: 1)
compute_dtype="bf16", # (default: "auto")
tier="large", # override auto-detected tier
),
data_config=DataConfig(
text_column="description", # (default: "text")
label_column="merchant", # (default: "label")
id_column="row_uuid", # (default: "id")
max_train_samples=5000, # (default: 100)
train_test_val_split=(0.7, 0.15, 0.15), # (default: (0.6, 0.2, 0.2))
data_source_filter="bank_of_america", # (default: None)
data_source_column="data_source", # (default: None)
prompt_registry_name="catalog.schema.my_prompt",
prompt_alias="production",
),
env_vars={"TOKENIZERS_PARALLELISM": "false"},
# --- Remote execution ---
trigger_remote=True, # submit as remote job (default: False)
cluster_type="serverless_gpu", # "serverless_gpu" or "classic_gpu" (default: "serverless_gpu")
gpu_type="h100", # "a10" or "h100" (default: "a10")
num_gpus=1, # number of GPUs (default: 1)
wait=True, # False returns RemoteJobRun handle (default: True)
)
result = run_fine_tuning(
training_table_name="catalog.schema.training_data",
model_name="catalog.schema.my_model",
base_model="qwen2_5_3b_instruct",
requirements=REQUIREMENTS,
trigger_remote=True, # submit as remote job
cluster_type="serverless_gpu", # (default: "serverless_gpu")
gpu_type="a10", # "a10" or "h100" (default: "a10")
num_gpus=1, # (default: 1)
wait=True, # False returns RemoteJobRun handle
)
Model signature — by default no MLflow signature is set on the registered model. You can provide custom input/output schemas as PySpark types or DDL strings. This controls how Databricks Model Serving validates request/response payloads.
from pyspark.sql.types import StructType, StructField, StringType, ArrayType
result = run_fine_tuning(
training_table_name="catalog.schema.training_data",
model_name="catalog.schema.my_ner_model",
base_model="qwen2_5_3b_instruct",
requirements=REQUIREMENTS,
input_schema=StructType([
StructField("text", StringType(), nullable=False),
StructField("context", StringType(), nullable=True),
]),
output_schema=ArrayType(StringType()),
)
result = run_fine_tuning(
training_table_name="catalog.schema.training_data",
model_name="catalog.schema.my_model",
base_model="qwen2_5_3b_instruct",
requirements=REQUIREMENTS,
input_schema="struct<prompt:string>",
output_schema="array<string>",
)
Available base models — pass any key as base_model, or use a direct HuggingFace model ID
(set tier in TrainingConfig when using a custom HF ID):
| Key | HuggingFace ID | Tier |
|---|---|---|
| | Qwen/Qwen3-30B-A3B | xlarge |
| | Qwen/Qwen3-14B | large |
| | Qwen/Qwen3-8B | large |
| | Qwen/Qwen3-4B | medium |
| | Qwen/Qwen3-4B-Instruct-2507 | medium |
| qwen2_5_7b_instruct | Qwen/Qwen2.5-7B-Instruct | large |
| qwen2_5_3b_instruct | Qwen/Qwen2.5-3B-Instruct | medium |
| | Qwen/Qwen2.5-1.5B-Instruct | small |
| | Qwen/Qwen2.5-0.5B-Instruct | xsmall |
| | HuggingFaceTB/SmolLM2-360M-Instruct | xsmall |
| | HuggingFaceTB/SmolLM2-135M-Instruct | xsmall |
| | TinyLlama/TinyLlama-1.1B-Chat-v1.0 | small |
| | deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B | small |
Using any HuggingFace model — you are not limited to the catalog above. Pass any HuggingFace
model ID directly as base_model (e.g., "meta-llama/Llama-3.1-8B-Instruct"). When using
a custom model ID, set tier in TrainingConfig to control batch sizes and gradient
accumulation (defaults to "medium" if not set):
result = run_fine_tuning(
training_table_name="catalog.schema.training_data",
model_name="catalog.schema.my_model",
base_model="meta-llama/Llama-3.1-8B-Instruct", # any HuggingFace model ID
requirements=REQUIREMENTS,
training_config=TrainingConfig(tier="large"), # set tier for resource allocation
)
Tier controls batch sizes and gradient accumulation in the Optuna HP search space:
| Tier | Batch sizes | Gradient accumulation |
|---|---|---|
| xsmall | [32, 64, 128] | [1, 2, 4] |
| small | [16, 32, 64] | [1, 2, 4, 8] |
| medium | [4, 8, 16] | [2, 4, 8] |
| large | [1, 2, 4] | [8, 16, 32] |
| xlarge | [1, 2] | [16, 32, 64] |
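The effective batch size per optimizer step is the per-device batch size times the gradient accumulation steps, which is why the tiers trade one against the other. A trivial illustration (the helper name is hypothetical):

```python
def effective_batch(batch_size, grad_accum_steps):
    """Samples consumed per optimizer step = micro-batch x accumulation."""
    return batch_size * grad_accum_steps

# Two different (batch size, accumulation) combinations from the search
# space can yield the same effective batch per optimizer step:
a = effective_batch(64, 2)   # large micro-batches, little accumulation
b = effective_batch(4, 32)   # small micro-batches, heavy accumulation
```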
Prompt template options: there are three ways to provide a prompt:
1. str.format (default): uses the `{text}` placeholder; only the `text_column` value is available.
   DataConfig(prompt_template="Extract entities from: {text}\n\nAnswer:\n")
2. Jinja2: uses `<<column>>` delimiters; all DataFrame columns are available. Auto-detected when the template contains `<<`, `{%`, or `{#`. Add `"jinja2"` to requirements.
   DataConfig(prompt_template="Transaction: <<text>>\nCandidates: <<candidates>>\nAnswer:\n")
3. MLflow Prompt Registry: loaded at runtime; `{{variable}}` is auto-converted to `<<variable>>`. Cannot be combined with `prompt_template`.
   DataConfig(
       prompt_registry_name="catalog.schema.my_prompt",
       prompt_alias="production",  # or prompt_version=1
   )
run_fine_tuning#
LoRAConfig#
TrainingConfig#
DataConfig#
Functions#
run_llm_batch#
estimate_token_usage#
query#
LLMResponse#
- class ml_toolkit.functions.llm.query.LLMResponse[source]#
Class that defines the LLM response of the `query` function. It exposes the following attributes:
- `text`: returns the text response
- `response`: returns the raw LLM response class (openai.ChatCompletion)
And also the following method:
- `.json()`: tries to parse the output into a python dict if possible
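The best-effort behaviour of `.json()` can be approximated like this (a sketch, not the actual implementation): strip any markdown code fences, attempt `json.loads`, and fall back to None on failure:

```python
import json
import re

def try_parse_json(text):
    """Best-effort parse of an LLM text response into a dict/list.
    Strips markdown code fences first; returns None if parsing fails."""
    cleaned = re.sub(r"^```[a-zA-Z0-9]*\n|\n?```$", "", text.strip())
    try:
        return json.loads(cleaned)
    except json.JSONDecodeError:
        return None

parsed = try_parse_json('```json\n{"answer": 42}\n```')
failed = try_parse_json("not json at all")
```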
run_vector_search#
GPU Inference#
The run_inference function runs vLLM batch inference on Databricks Serverless
GPU clusters using self-hosted Unity Catalog models. It wraps
run_serverless_inference and submits the work as a remote job via
run_remote(), so you can launch inference from any notebook — no GPU cluster
required on the caller side.
The function validates inputs (table names, GPU config), serializes parameters, and submits a one-time Databricks workflow that provisions GPUs, runs vLLM inference via Ray Data, and writes results to a Delta table.
from ml_toolkit.functions.llm import run_inference
remote_run = run_inference(
"catalog.schema.input_table",
"catalog.schema.my_model",
"catalog.schema.output",
uc_volumes_path="/Volumes/catalog/schema/vol",
)
# Track in the Databricks UI
print(remote_run.databricks_url)
print(remote_run.job_run_id)
# When ready, block until the job finishes
result = remote_run.get_result()
result = run_inference(
"catalog.schema.input_table",
"catalog.schema.my_model",
"catalog.schema.output",
uc_volumes_path="/Volumes/catalog/schema/vol",
wait=True,
)
run_inference(
"catalog.schema.input_table",
"catalog.schema.my_model",
"catalog.schema.output",
uc_volumes_path="/Volumes/catalog/schema/vol",
trigger_remote=False,
)
remote_run = run_inference(
"catalog.schema.input_table",
"catalog.schema.my_model",
"catalog.schema.output",
uc_volumes_path="/Volumes/catalog/schema/vol",
num_gpus=8,
gpu_type="h100",
temperature=0.0,
max_new_tokens=1024,
tensor_parallel_size=8,
)
GPU types and constraints:
| GPU type | num_gpus | Notes |
|---|---|---|
| a10 | Any (default: 4) | Good for models up to ~8B parameters. Default. |
| h100 | Multiple of 8 | Required for large models (>8B). |
Auto-loading behaviour: When the model was registered with model_config
and prompt_uris via register_LLM_model, generation parameters
(temperature, max_new_tokens, max_model_len, etc.) and prompt
templates are loaded automatically. Explicit parameters always take precedence.
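The precedence rule ("explicit parameters always take precedence") amounts to a dict merge in which caller-supplied values override the stored model_config defaults. A minimal sketch (the helper name is hypothetical, and None is assumed to mean "not provided"):

```python
def resolve_generation_params(model_config, **explicit):
    """Merge stored model_config defaults with explicit call-site parameters.
    Explicit values win; None means 'not provided' and falls back to config."""
    resolved = dict(model_config)
    for key, value in explicit.items():
        if value is not None:
            resolved[key] = value
    return resolved

stored = {"temperature": 0.7, "max_new_tokens": 512}
params = resolve_generation_params(stored, temperature=0.0, max_new_tokens=None)
```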
run_inference#
- ml_toolkit.functions.llm.inference.interface.run_inference(input_source: str | pyspark.sql.DataFrame, model_name: str, output_table: str, *, prompt_column: str = 'prompt', model_version: int | Literal['latest'] | None = None, model_alias: str | None = None, system_prompt: str | None = None, max_new_tokens: int | None = None, temperature: float | None = None, top_p: float | None = None, top_k: int | None = None, min_p: float | None = None, presence_penalty: float | None = 0.0, repetition_penalty: float | None = 1.0, max_model_len: int | None = None, num_gpus: int = 4, gpu_type: str = 'a10', concurrency: int | None = None, batch_size: int | None = None, tensor_parallel_size: int = 1, gpu_memory_utilization: float | None = 0.85, max_num_seqs: int | None = None, max_num_batched_tokens: int | None = None, kv_cache_dtype: str | None = None, dtype: str | None = None, quantization: str | None = None, uc_volumes_path: str | None = None, auto_build_prompt: bool = True, output_column: str = 'model_output', pk_column: str | None = None, table_operation: Literal['overwrite', 'append'] | None = 'overwrite', ray_object_store_memory: float = 0.5, http_pool_maxsize: int = 64, trigger_remote: bool = True, wait: bool = False, compute_type: Literal['serverless', 'classic', 'anyscale'] = 'serverless', capacity_reservation_id: str | None = 'auto', availability_zone: str | None = 'auto', ray_kwargs: dict | None = None, env_vars: dict[str, str] | None = None, output_schema: object | None = None, chat_template_kwargs: dict | None = None, output_storage_location: str | None = None, output_bucket: str | None = None, checkpoint_path: str | None = None, anyscale_kwargs: AnyscaleKwargs | dict | None = None) ml_toolkit.functions.llm.inference.function.RemoteInferenceRun | dict | None[source]#
Run vLLM batch inference on Databricks GPU compute.
Supports two compute backends via `compute_type`:
- `"serverless"` (default): Databricks Serverless GPU. Uses A10 or H100 accelerators. Requires `uc_volumes_path` for data staging.
- `"classic"`: Databricks classic cluster with p5.48xlarge (8x H100) or p6-b200.48xlarge (8x B200) nodes. Uses Ray-on-Spark for distributed inference. GPUs in multiples of 8. The driver contributes 8 GPUs; the number of additional worker nodes is auto-computed as `max(0, (num_gpus - 8) // 8)`. Supports AWS Capacity Blocks.
When `trigger_remote=True` (default), the function validates inputs, serializes parameters, and submits a remote job via `run_remote()`.
When `trigger_remote=False`, inference runs directly on the current cluster (which must already have GPUs available).
- Parameters:
  - input_source – Fully qualified UC table name (`catalog.schema.table`) or DataFrame.
  - model_name – UC model name (`catalog.schema.model`).
  - output_table – Fully qualified output table name.
  - prompt_column – Column containing prompt text. Defaults to `"prompt"`.
  - model_version – Specific model version (int) or `"latest"`. Mutually exclusive with `model_alias`. None = production alias.
  - model_alias – UC alias (e.g. `"champion"`). Mutually exclusive with `model_version`.
  - system_prompt – System instruction for chat completions.
max_new_tokens – Maximum tokens to generate per prompt.
temperature – Sampling temperature (0.0 = greedy).
top_p – Nucleus sampling threshold.
top_k – Controls the number of top tokens to consider. Set to 0 (or -1) to consider all tokens.
min_p – Represents the minimum probability for a token to be considered, relative to the probability of the most likely token. Must be in [0, 1].
presence_penalty – Penalizes new tokens based on whether they appear in the generated text so far. Values > 0 encourage the model to use new tokens, while values < 0 encourage the model to repeat tokens.
repetition_penalty – Penalizes new tokens based on whether they appear in the prompt and the generated text so far. Values > 1 encourage the model to use new tokens, while values < 1 encourage the model to repeat tokens.
max_model_len – Maximum sequence length.
  - num_gpus – Number of GPUs. Defaults to 4. For `compute_type="classic"`, must be a multiple of 8 (H100 only) and counts the total across driver + workers. `num_gpus=8` runs driver-only; `num_gpus=16` runs driver + 1 worker node.
  - gpu_type – GPU type: `"a10"`, `"h100"`, or `"b200"`. Defaults to `"a10"`. For `compute_type="classic"`, must be `"h100"` or `"b200"`.
  - concurrency – Number of vLLM data-parallel replicas.
batch_size – Rows per inference batch.
tensor_parallel_size – GPUs per vLLM instance.
gpu_memory_utilization – Fraction of GPU memory to use for vLLM.
max_num_seqs – Maximum concurrent sequences per vLLM engine.
max_num_batched_tokens – Maximum tokens per vLLM iteration step.
  - kv_cache_dtype – Data type for the KV cache (e.g. `"fp8"`).
  - dtype – Model weight data type for vLLM (e.g. `"bfloat16"`, `"float16"`, `"auto"`). None = auto-resolved from the model config.
  - quantization – vLLM quantization method. `"auto"` (default) = resolved from model_config → GPU defaults (`"fp8"` for H100/B200, None for A10). `"fp8"` = dynamically quantize weights to FP8_E4M3 at load time. None = don't pass quantization to vLLM (auto-detect from model).
  - uc_volumes_path – UC Volume path for staging data. Required.
auto_build_prompt – Auto-build prompt column from model template.
  - output_column – Name of the output column containing model predictions. Defaults to `"model_output"`.
  - pk_column – Name of an existing column to use as the row identifier for joining inference results back to the input. If None (default), a synthetic row ID is generated automatically. Use this when your input already has a unique primary key column, to avoid creating a redundant one.
  - table_operation – How to write the output table: `"overwrite"` (default) replaces the table, `"append"` adds rows to an existing table.
  - ray_object_store_memory – Fraction of cluster memory for the Ray object store. Defaults to 0.5.
http_pool_maxsize – Max concurrent HTTP connections for model download.
trigger_remote – If True, submit via run_remote(). If False, run directly on the current cluster.
wait – If True and trigger_remote=True, block until completion and return the result dict.
compute_type – "serverless" (default) or "classic".
capacity_reservation_id – AWS Capacity Block reservation ID (e.g. "cr-0abc1234def56789a"). Classic only. Configures the cluster with on-demand instances and the X-Databricks-AwsCapacityBlockId tag. Pass "auto" to fetch from the WORKSPACE_CONFIGURATION secret scope (keys GPU_CAPACITY_RESERVATION_ID and GPU_AVAILABILITY_ZONE).
availability_zone – AWS availability zone for the Capacity Block (e.g. "us-east-1a"). Required when using capacity_reservation_id (auto-resolved when "auto").
ray_kwargs – Dict of Ray-on-Spark tuning parameters forwarded to run_classic_inference (classic only). Supported keys: num_cpus_per_node, num_cpus_head_node, num_gpus_head_node.
env_vars – Optional dictionary of environment variables to set in each Ray worker process before inference starts. Ray actors do not inherit the driver’s environment, so diagnostic variables like VLLM_LOGGING_LEVEL=DEBUG must be forwarded explicitly. Defaults to None (no extra variables).
output_schema – Optional PySpark DataType describing the JSON shape the model is expected to produce (e.g. ArrayType(StructType([...]))). When set, the output table gets a parsed model_output struct column alongside the raw text in model_output_raw. Accepts a DataType instance, a DDL string (e.g. "array<struct<l1_category:string,...>>"), or the dict form produced by ml_toolkit.ops.helpers.mlflow.serialize_pyspark_schema(). Required for HuggingFace model IDs (no MLflow signature); for UC models, overrides whatever the MLflow signature declared.
chat_template_kwargs – Extra keyword arguments forwarded to the tokenizer’s apply_chat_template call when prompts are rendered. Use this to toggle model-family-specific flags such as {"enable_thinking": False} for Qwen3 (suppresses <think> blocks) or {"thinking": True} for DeepSeek-V3.1 / Granite 3.2. The toolkit does not interpret the dict — it splats the keys into the call, so callers are responsible for using the kwarg name their model expects. Defaults to None (preserves prior behaviour). Applied on both serverless and classic compute paths.
- Returns:
RemoteInferenceRun when trigger_remote=True and wait=False.
Result dict when trigger_remote=True and wait=True.
None when trigger_remote=False.
- Raises:
ValueError – If validation fails (table format, GPU config, etc.).
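The classic-compute GPU constraints above (H100/B200 only and, for H100, a num_gpus that is a multiple of 8) can be sketched as a standalone pre-flight check. The helper below is illustrative only; it is not part of ml_toolkit, and the toolkit's internal validation may differ.

```python
def validate_classic_gpu_config(num_gpus: int, gpu_type: str) -> None:
    # Hypothetical helper mirroring the documented rules; not ml_toolkit API.
    if gpu_type not in ("h100", "b200"):
        raise ValueError(
            f'compute_type="classic" supports gpu_type "h100" or "b200", got "{gpu_type}"'
        )
    # For H100, num_gpus counts the total across driver + workers and must be
    # a multiple of 8: num_gpus=8 runs driver-only, num_gpus=16 adds one worker.
    if gpu_type == "h100" and (num_gpus <= 0 or num_gpus % 8 != 0):
        raise ValueError(
            f"num_gpus must be a positive multiple of 8 for H100 classic compute, got {num_gpus}"
        )
```

Running this check before submitting a remote job surfaces configuration mistakes locally instead of after cluster spin-up.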
RemoteInferenceRun#
Serving Endpoints#
The serving_endpoints submodule provides functions for deploying and managing
Databricks serving endpoints for Unity Catalog LLM models. It supports single-version
and multi-version deployments with automatic traffic management using predefined
LLM serving standards.
LLM Serving Standards — five predefined tiers for endpoint configuration:
| Standard | Min tokens/sec | Max tokens/sec | Use case |
|---|---|---|---|
| XSMALL | 0 | 1,000 | Dev / testing (scale-to-zero) |
| SMALL | 1,000 | 5,000 | Light production (default) |
| MEDIUM | 5,000 | 10,000 | Standard production |
| LARGE | 10,000 | 30,000 | High throughput |
| XLARGE | 30,000 | 70,000 | Maximum throughput |
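As a rough illustration, the tiers can be modeled as a lookup table. This assumes the table rows map to the five scaling sizes in order (XSMALL through XLARGE); the dict and helper below are hypothetical, not ml_toolkit API.

```python
# Token-per-second ranges per serving standard, assuming the tiers map to the
# table rows in order. Illustrative only; not part of ml_toolkit.
SERVING_STANDARDS = {
    "XSMALL": (0, 1_000),        # dev / testing (scale-to-zero)
    "SMALL": (1_000, 5_000),     # light production (default)
    "MEDIUM": (5_000, 10_000),   # standard production
    "LARGE": (10_000, 30_000),   # high throughput
    "XLARGE": (30_000, 70_000),  # maximum throughput
}

def pick_standard(expected_tokens_per_sec: int) -> str:
    """Return the smallest standard whose max throughput covers the expected load."""
    # Dicts preserve insertion order, so this iterates smallest tier first.
    for name, (_min_tps, max_tps) in SERVING_STANDARDS.items():
        if expected_tokens_per_sec <= max_tps:
            return name
    raise ValueError("Expected load exceeds the XLARGE ceiling of 70,000 tokens/sec")
```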
from ml_toolkit.functions.llm import (
deploy_model_serving_endpoint,
add_model_version_to_endpoint,
update_endpoint_traffic,
query_serving_endpoint,
)
# 1. Deploy a new endpoint (name inferred as "tiny_llama")
endpoint = deploy_model_serving_endpoint(
model_name="catalog.schema.tiny_llama",
model_versions=[1],
standard="XSMALL",
cost_component_name="ml_research",
wait_for_ready=True,
)
# 2. Add a second version with equal traffic
endpoint = add_model_version_to_endpoint(
version=2,
model_name="catalog.schema.tiny_llama",
standard="XSMALL",
)
# 3. Shift traffic to the new version
update_endpoint_traffic(
model_name="catalog.schema.tiny_llama",
traffic_config={1: 20, 2: 80},
)
# 4. Query a specific version
result_df = query_serving_endpoint(
df=spark.table("prompts"),
endpoint_name="tiny_llama",
version=2,
prompt_column="text",
)
deploy_model_serving_endpoint#
- ml_toolkit.functions.llm.serving_endpoints.interface.deploy_model_serving_endpoint(model_name: str, model_versions: List[int], endpoint_name: str | None = None, endpoint_scaling_size: ml_toolkit.functions.llm.serving_endpoints.constants.EndpointScalingSize = 'SMALL', provisioning_type: ml_toolkit.functions.llm.serving_endpoints.constants.ProvisioningType | None = None, traffic_config: Dict[int, int] | None = None, cost_component_name: str | None = None, tags: Dict[str, str] | None = None, scale_to_zero_enabled: bool = True, wait_for_ready: bool = DEFAULT_WAIT_FOR_READY, timeout: int = DEFAULT_TIMEOUT) Dict[source]#
Deploy one or more versions of a Unity Catalog model to a NEW serving endpoint.
This function creates a new serving endpoint. If the endpoint already exists with any served models, it will raise MLOpsToolkitEndpointAlreadyExistsException. Use add_model_version_to_endpoint() to add versions to an existing endpoint.
Automatically determines whether to use provisioned throughput or GPU config by calling the Databricks get-model-optimization-info API. For provisioned throughput models, token values are rounded to the model’s chunk_size increment. For GPU models, the endpoint_scaling_size maps to a workload_type + workload_size combo.
- Parameters:
model_name – Unity Catalog model path (catalog.schema.model). Required. Example: "yd_ml_dpe.fine_tune_staging.qlora_tinyllama_model"
model_versions – List of Unity Catalog model version numbers (integers) to deploy. Each entry corresponds to a version on the endpoint. Required. Example: [1] for a single version, [1, 2] for two versions. The served model names will be v1, v2, etc.
endpoint_name – Optional name of the serving endpoint to create. If not provided, the endpoint name is automatically inferred from the model name (the last part after the final dot). Must not already exist with served models, or MLOpsToolkitEndpointAlreadyExistsException will be raised. Example: if model_name is "catalog.schema.my_model", endpoint_name defaults to "my_model".
endpoint_scaling_size – Endpoint scaling size. Defaults to "SMALL". Must be one of "XSMALL", "SMALL", "MEDIUM", "LARGE", "XLARGE". For provisioned throughput models, this maps to token-per-second ranges. For GPU models, this maps to workload_type + workload_size combos: XSMALL=GPU_SMALL/Small, SMALL=GPU_MEDIUM/Small, MEDIUM=GPU_MEDIUM/Medium, LARGE=MULTIGPU_MEDIUM/Medium, XLARGE=GPU_MEDIUM_8/Large.
provisioning_type – Optional override for the provisioning type. If not provided, the type is auto-detected by calling the Databricks optimization-info API. Must be "PROVISIONED_THROUGHPUT" or "GPU".
traffic_config – Optional traffic distribution dict mapping version numbers to traffic percentages. If None, traffic is distributed equally. Example: {1: 70, 2: 30} means version 1 gets 70% and version 2 gets 30%.
cost_component_name – Required cost component for tracking. If not provided, the function attempts to determine it from settings or the environment, and raises ValueError if not found. Always specify this explicitly for production deployments.
tags – Optional additional tags as key-value pairs.
scale_to_zero_enabled – Whether to enable scale-to-zero for the endpoint (default: True). When True, the endpoint scales to zero when idle; when False, it stays always warm.
wait_for_ready – Whether to wait for the endpoint to be ready before returning.
timeout – Maximum seconds to wait if wait_for_ready is True.
- Return type:
Endpoint details dictionary
Examples
Deploy provisioned throughput model (auto-detected)#
from ml_toolkit.functions.llm.serving_endpoints import deploy_model_serving_endpoint
endpoint = deploy_model_serving_endpoint(
    model_name="catalog.schema.tiny_llama",
    model_versions=[1],
    endpoint_scaling_size="XSMALL",
    cost_component_name="ml_research",
    wait_for_ready=True,
)
Deploy GPU model with explicit provisioning type#
endpoint = deploy_model_serving_endpoint(
    model_name="catalog.schema.custom_model",
    model_versions=[1],
    endpoint_scaling_size="MEDIUM",
    provisioning_type="GPU",
    cost_component_name="data_integration_rd",
)
Deploy multiple versions with A/B testing#
endpoint = deploy_model_serving_endpoint(
    model_name="catalog.schema.llama_model",
    model_versions=[1, 2],
    traffic_config={1: 50, 2: 50},
    endpoint_scaling_size="LARGE",
    cost_component_name="ml_research",
)
add_model_version_to_endpoint#
- ml_toolkit.functions.llm.serving_endpoints.interface.add_model_version_to_endpoint(version: int, model_name: str, endpoint_name: str | None = None, endpoint_scaling_size: ml_toolkit.functions.llm.serving_endpoints.constants.EndpointScalingSize = 'SMALL', provisioning_type: ml_toolkit.functions.llm.serving_endpoints.constants.ProvisioningType | None = None, traffic_percentage: int | None = None, redistribute_traffic: bool = True, scale_to_zero_enabled: bool = True, wait_for_ready: bool = DEFAULT_WAIT_FOR_READY, timeout: int = DEFAULT_TIMEOUT) Dict[source]#
Add a new model version to an existing serving endpoint.
This function fetches the current endpoint configuration, auto-detects the model name and version from existing versions, and adds the new version with automatic traffic distribution.
Automatically determines whether to use provisioned throughput or GPU config. Validates that the new entity’s provisioning type matches the existing endpoint entities.
- Parameters:
version – Integer version identifier (e.g., 1, 2, 3). The served model name will be v{version} (e.g., v1, v2, v3).
model_name – Unity Catalog model path (catalog.schema.model). Required. Used to infer the endpoint name if endpoint_name is not explicitly provided.
endpoint_name – Optional name of the existing serving endpoint. If not provided, the endpoint name is automatically inferred from model_name.
endpoint_scaling_size – Endpoint scaling size. Defaults to "SMALL". Must be one of "XSMALL", "SMALL", "MEDIUM", "LARGE", "XLARGE". For provisioned throughput models, this maps to token-per-second ranges. For GPU models, this maps to workload_type + workload_size combos.
provisioning_type – Optional override for the provisioning type. If not provided, the type is auto-detected by calling the Databricks optimization-info API. Must be "PROVISIONED_THROUGHPUT" or "GPU".
traffic_percentage – Percentage of traffic to route to this new version (0-100). If None and redistribute_traffic=True, traffic is distributed equally. If specified, other versions’ traffic is proportionally reduced.
redistribute_traffic – If True, automatically adjust traffic across all versions. If False, the new version gets 0% traffic (you must manually call update_endpoint_traffic).
scale_to_zero_enabled – Whether to enable scale-to-zero for the new version (default: True). When True, the endpoint scales to zero when idle; when False, it stays always warm.
wait_for_ready – Whether to wait for the endpoint to be ready before returning.
timeout – Maximum seconds to wait if wait_for_ready is True.
- Return type:
Updated endpoint details dictionary
Examples
Add version using model_name (recommended)#
from ml_toolkit.functions.llm.serving_endpoints import add_model_version_to_endpoint
endpoint = add_model_version_to_endpoint(
    version=3,
    model_name="catalog.schema.my_model",
    endpoint_scaling_size="MEDIUM",
)
Add version with explicit provisioning type#
endpoint = add_model_version_to_endpoint(
    version=2,
    endpoint_name="production_endpoint",
    endpoint_scaling_size="LARGE",
    provisioning_type="GPU",
    traffic_percentage=10,
)
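One way to picture the proportional redistribution applied when traffic_percentage is specified is the sketch below. The arithmetic (including how rounding drift is absorbed) is an assumption for illustration, not the toolkit's exact algorithm.

```python
def redistribute(current: dict, new_version: int, traffic_percentage: int) -> dict:
    # Hypothetical sketch: give the new version its requested share and scale
    # the existing versions proportionally into the remainder.
    remaining = 100 - traffic_percentage
    total = sum(current.values())
    scaled = {v: round(p * remaining / total) for v, p in current.items()}
    # Assign whatever is left to the new version so the total is exactly 100,
    # absorbing any rounding drift.
    scaled[new_version] = 100 - sum(scaled.values())
    return scaled
```

For example, adding version 3 at 10% to an endpoint currently split 60/40 between versions 1 and 2 would scale the old versions into the remaining 90%.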
update_endpoint_traffic#
- ml_toolkit.functions.llm.serving_endpoints.interface.update_endpoint_traffic(traffic_config: Dict[int, int], model_name: str, endpoint_name: str | None = None, wait_for_ready: bool = DEFAULT_WAIT_FOR_READY, timeout: int = DEFAULT_TIMEOUT) Dict[source]#
Update the traffic distribution for an existing multi-version endpoint.
This is useful for gradual rollouts, canary deployments, or A/B testing scenarios where you want to adjust traffic percentages without redeploying models.
- Parameters:
traffic_config – Dict mapping version numbers to traffic percentages. Must sum to 100. Example: {1: 30, 2: 70}
model_name – Unity Catalog model path (catalog.schema.model). Required. Used to infer the endpoint name if endpoint_name is not explicitly provided.
endpoint_name – Optional name of the serving endpoint. If not provided, the endpoint name is automatically inferred from model_name.
wait_for_ready – Whether to wait for the endpoint to be ready after the update.
timeout – Maximum seconds to wait if wait_for_ready is True.
- Return type:
Updated endpoint details dictionary
Examples
Gradually increase traffic to new version using model_name#
# Start: 90% version 1, 10% version 2
update_endpoint_traffic(
    model_name="catalog.schema.my_model",
    traffic_config={1: 90, 2: 10},
)
# After monitoring: 50% split
update_endpoint_traffic(
    model_name="catalog.schema.my_model",
    traffic_config={1: 50, 2: 50},
)
# Final: 100% version 2
update_endpoint_traffic(
    model_name="catalog.schema.my_model",
    traffic_config={1: 0, 2: 100},
)
Using explicit endpoint name#
update_endpoint_traffic(
    endpoint_name="my-endpoint",
    traffic_config={1: 50, 2: 50},
)
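Because traffic_config must sum to 100, a simple pre-flight check can catch mistakes before hitting the API. The helper below is a hypothetical convenience, not part of ml_toolkit.

```python
def check_traffic_config(traffic_config: dict) -> None:
    # Hypothetical pre-flight check; update_endpoint_traffic performs its own validation.
    if any(not (0 <= p <= 100) for p in traffic_config.values()):
        raise ValueError("each traffic percentage must be between 0 and 100")
    total = sum(traffic_config.values())
    if total != 100:
        raise ValueError(f"traffic percentages must sum to 100, got {total}")
```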
get_serving_endpoint#
- ml_toolkit.functions.llm.serving_endpoints.interface.get_serving_endpoint(model_name: str, endpoint_name: str | None = None) Dict[source]#
Get details of a serving endpoint.
- Parameters:
model_name – Unity Catalog model name (catalog.schema.model_name). Required; used to infer endpoint_name if not provided.
endpoint_name – Name of the endpoint. Optional; if not provided, automatically inferred from model_name.
- Return type:
Endpoint details dictionary containing configuration, state, and metadata
Examples
from ml_toolkit.functions.llm.serving_endpoints import get_serving_endpoint
# Endpoint name automatically inferred from model name
endpoint = get_serving_endpoint(model_name="catalog.schema.tiny_llama")
# Endpoint name inferred as "tiny_llama"
print(f"State: {endpoint['state']['ready']}")
print(f"Served entities: {endpoint['config']['served_entities']}")
# Or specify explicit endpoint name
endpoint = get_serving_endpoint(
    model_name="catalog.schema.tiny_llama",
    endpoint_name="my-custom-endpoint"
)
list_serving_endpoints#
- ml_toolkit.functions.llm.serving_endpoints.interface.list_serving_endpoints(filter_tags: Dict[str, str] | None = None) List[Dict][source]#
List all serving endpoints in the workspace.
- Parameters:
filter_tags – Optional dict of tags to filter by. Only endpoints matching ALL specified tags will be returned.
- Return type:
List of endpoint dictionaries
Examples
from ml_toolkit.functions.llm.serving_endpoints import list_serving_endpoints
# List all endpoints
endpoints = list_serving_endpoints()
for ep in endpoints:
    print(f"{ep['name']}: {ep['state']['ready']}")
# Filter by cost component
my_team_endpoints = list_serving_endpoints(
    filter_tags={"cost_component_name": "ml_research"}
)
# Filter by multiple tags
prod_endpoints = list_serving_endpoints(
    filter_tags={"env": "production", "team": "ai-platform"}
)
query_serving_endpoint#
- ml_toolkit.functions.llm.serving_endpoints.helpers.query.query_serving_endpoint(df: pyspark.sql.DataFrame, endpoint_name: str, version: int | None = None, prompt_column: str = 'prompt', prompt_template: str = None, max_tokens: int = 512, temperature: float = 0.7, output_column: str = 'output') pyspark.sql.DataFrame[source]#
Query a specific model version in a multi-version serving endpoint.
This convenience function prepares the DataFrame and calls the UDTF, automatically routing to the specified version (v1, v2, v3, etc.).
- Parameters:
df – Input DataFrame with a prompt column
endpoint_name – Name of the serving endpoint
version – Integer version identifier (e.g., 1, 2, 3). Defaults to None (latest version). Routes to served model “v{version}” (e.g., v1, v2, v3). If None, automatically resolves to the highest available version.
prompt_column – Name of the column containing prompts (default: “prompt”). Ignored if prompt_template is specified.
prompt_template – Optional template string with <<column>> placeholders. Use double angle brackets to reference DataFrame columns. Examples: “Classify: <<text>>”, “Extract entity: <<merchant_name>>”
max_tokens – Maximum output tokens (default: 512)
temperature – Sampling temperature (default: 0.7)
output_column – Name of the output column (default: “output”)
- Returns:
DataFrame with columns {output_column}, start_timestamp, end_timestamp, error_message
- Return type:
DataFrame
Example
>>> from ml_toolkit.functions.llm.serving_endpoints.helpers.query import query_serving_endpoint
>>> from pyspark.sql import SparkSession
>>>
>>> spark = SparkSession.builder.getOrCreate()
>>> df = spark.table("input_table")
>>>
>>> # Query latest version (default)
>>> result_df = query_serving_endpoint(
...     df=df,
...     endpoint_name="my_endpoint",
...     prompt_column="text",
...     max_tokens=1000,
...     temperature=0.5
... )
>>>
>>> # Query specific version
>>> result_df = query_serving_endpoint(
...     df=df,
...     endpoint_name="my_endpoint",
...     version=1,
...     prompt_template="Classify this <<category>> text: <<content>>",
...     max_tokens=1000,
...     temperature=0.5
... )
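The <<column>> placeholder convention accepted by prompt_template can be illustrated with a small standalone renderer. The real substitution happens inside query_serving_endpoint; the function below exists only to show the convention and is not ml_toolkit API.

```python
import re

def render_prompt_template(template: str, row: dict) -> str:
    # Replace each <<column>> placeholder with the matching value from the row.
    # Hypothetical illustration of the documented convention.
    return re.sub(r"<<(\w+)>>", lambda m: str(row[m.group(1)]), template)
```

Each row of the input DataFrame would be rendered this way before being sent to the endpoint.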
Model Registration#
The model registration functions provide a complete workflow for logging and registering models into Databricks Unity Catalog via MLflow. Two dedicated paths are available:
HuggingFace LLMs — use register_LLM_model to log a model + tokenizer via mlflow.transformers.
Generic pyfunc models — use register_model to log any custom PythonModel via mlflow.pyfunc.
Both functions support tagging, descriptions, artifact copying, prompt attachment, and optional auto-deployment to an alias on registration.
from ml_toolkit import register_LLM_model
model_info = register_LLM_model(
model_name="catalog.schema.llama_3_1_8b",
model=model,
tokenizer=tokenizer,
torch_dtype="bfloat16",
model_config={"max_new_tokens": 512, "temperature": 0.1},
params={"learning_rate": 2e-4, "epochs": 3},
metrics={"train_loss": 0.85},
model_tags={"team": "nlp"},
target_alias="prod",
)
print(f"Registered version {model_info.version}")
from ml_toolkit import register_model
model_info = register_model(
model_name="catalog.schema.my_classifier",
custom_model=MyCustomModel(),
source_artifacts_dir="/dbfs/tmp/model_weights",
pip_requirements=["scikit-learn==1.4.0"],
target_alias="prod",
)
register_LLM_model#
register_model#
deploy_model#
log_artifacts#
log_artifact#
Model Loading#
The model loading functions handle fetching registered models from Unity Catalog or HuggingFace Hub, with built-in retry logic for transient network or cluster errors.
from ml_toolkit import load_model_with_retry, load_auto_tokenizer_with_retry, get_model_info
# Load by UC model name (latest version by default)
model, tokenizer = load_model_with_retry(
model_checkpoint="catalog.schema.llama_3_1_8b",
return_type="components",
)
# Inspect model version metadata
info = get_model_info("catalog.schema.llama_3_1_8b")
info.display()
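The retry behaviour can be pictured as a generic exponential-backoff loop. The sketch below is an assumed general pattern; load_model_with_retry's actual attempt counts, delays, and retried exception types may differ.

```python
import time

def load_with_retry(load_fn, max_attempts: int = 3, base_delay: float = 1.0):
    # Generic exponential-backoff sketch, not ml_toolkit's implementation.
    # Retries load_fn on any exception, doubling the delay after each failure.
    for attempt in range(max_attempts):
        try:
            return load_fn()
        except Exception:
            if attempt == max_attempts - 1:
                raise  # out of attempts: surface the last error
            time.sleep(base_delay * 2 ** attempt)
```

Transient cluster or network errors during model download are the kind of failure this pattern smooths over.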