llm module#

The llm module of the ml_toolkit contains all the functions that enable LLM usage within Databricks notebooks. We expose the following functions:

  • run_fine_tuning: fine-tune an LLM using QLoRA with Optuna hyperparameter optimization, with local or remote GPU execution.

  • run_llm_batch: performs row-level LLM querying over your data, writing each response to a new column.

  • estimate_token_usage: estimates token usage of the run_llm_batch function.

  • query: directly query a model, optionally providing context and tools to perform an action.

  • run_vector_search: performs vector search (nearest or hybrid) over a Databricks vector search index.

  • run_inference: runs vLLM batch inference on Databricks Serverless GPU clusters using self-hosted Unity Catalog models.

PySpark UDTFs for at-scale inference:

  • setup_udtf: register any UDTF by passing the appropriate config (agent or multi-model).

  • run_agent_batch: one-call batch execution of an agent UDTF with table writes.

Prompt management functions for the MLflow Prompt Registry:

  • create_or_fetch_prompt: create or fetch prompts with automatic deduplication.

  • load_prompt: load a prompt by name, version, or URI/alias.

  • search_prompts: search for prompts in the registry.

  • set_prompt_alias: set aliases (e.g., “production”) for prompt versions.

  • delete_prompt_version: delete a specific version of a prompt.

  • delete_prompt: delete a prompt from the registry.

Serving endpoint management:

  • deploy_model_serving_endpoint: deploy Unity Catalog model versions to a new serving endpoint.

  • add_model_version_to_endpoint: add a model version to an existing endpoint with automatic traffic redistribution.

  • update_endpoint_traffic: update traffic distribution for A/B testing and canary deployments.

  • get_serving_endpoint: get endpoint details.

  • list_serving_endpoints: list all endpoints with optional tag filtering.

  • query_serving_endpoint: query a specific version of a serving endpoint.

Model registration functions for Unity Catalog:

  • register_LLM_model: register a HuggingFace LLM (model + tokenizer) to Unity Catalog via mlflow.transformers.

  • register_model: register a generic pyfunc model to Unity Catalog via mlflow.pyfunc.

  • deploy_model: deploy a registered model version to a specified alias (e.g. "prod").

  • log_artifacts: copy a directory of artifacts to the model version’s UC Volume path.

  • log_artifact: copy a single artifact to the model version’s UC Volume path.

Model loading and inspection functions:

  • get_model_info: fetch metadata for a registered model version (tags, alias, volume path).

  • load_model_with_retry: load a model from UC or HuggingFace Hub with automatic retry logic.

  • load_auto_tokenizer_with_retry: load a HuggingFace tokenizer with automatic retry logic.

Attention

These functions are in an experimental phase and are subject to change. If you have any feedback, please submit it via our Jira Form.

Attention

You must pass a cost_component_name to the functions that call the LLMs, otherwise they will raise exceptions.

Quota Controls#

We have set strict usage controls (quotas) on LLM calls. Once you exceed your token limit, an exception is raised. LLM usage can become expensive quickly, so these limits exist to prevent runaway costs.

If you have a valid use case that solves a business problem and needs a higher quota, please submit a ticket. The approval process requires your manager to sign off on the use case and the requested budget. Always include the output of estimate_token_usage in your request.
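
For illustration, a minimal sketch of producing that estimate. The argument names below (data_source, prompt_template) are assumptions for this example; the exact signature is documented under estimate_token_usage below.

from ml_toolkit.functions.llm import estimate_token_usage

# Hypothetical arguments — check the estimate_token_usage entry below for the real signature.
estimate = estimate_token_usage(
    data_source="catalog.schema.input_prompts",
    prompt_template="Classify this transaction: <<text>>",
)
print(estimate)  # attach this output to your quota ticket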

Prompts#

The prompt functions provide a complete interface for managing prompts in the MLflow Prompt Registry. Prompts are versioned, support aliases (e.g., “production”, “staging”), and use content-based deduplication.

Quick start#
from ml_toolkit.functions.llm import create_or_fetch_prompt, load_prompt, set_prompt_alias

# Create or fetch a prompt (deduplicates automatically)
result = create_or_fetch_prompt(
    name="mycatalog.myschema.qa_prompt",
    template="Answer this question: {{question}}",
    commit_message="Initial QA prompt"
)

# Set an alias for easy reference
set_prompt_alias("mycatalog.myschema.qa_prompt", alias="production", version=result["version"])

# Load by alias
prompt = load_prompt("prompts:/mycatalog.myschema.qa_prompt@production")
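
The remaining prompt helpers follow the same pattern. A minimal sketch, assuming search_prompts accepts an MLflow-style filter string and the delete helpers take the prompt name (plus a version for delete_prompt_version); see the per-function entries below for the exact signatures.

from ml_toolkit.functions.llm import search_prompts, delete_prompt_version, delete_prompt

# Hypothetical filter string — see search_prompts below for the supported syntax.
hits = search_prompts("name LIKE 'mycatalog.myschema.%'")

# Remove a single obsolete version, or the whole prompt.
delete_prompt_version("mycatalog.myschema.qa_prompt", version=1)
delete_prompt("mycatalog.myschema.qa_prompt")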

create_or_fetch_prompt#

load_prompt#

search_prompts#

set_prompt_alias#

delete_prompt_version#

delete_prompt#

UDTFs (User-Defined Table Functions)#

The udtf submodule provides PySpark UDTFs for running LLM inference at scale inside Spark SQL queries. All UDTFs are registered through a single function, setup_udtf, which dispatches based on the config type:

  • AgentUDTFConfig — Runs an OpenAI Agents SDK agent (with tools, Jinja2 instructions) per row.

  • MultiModelUDTFConfig — Sends each prompt to multiple models in parallel and yields one row per model.

  • InferenceConfig — Runs a single-model chat completion per row with prompt templating.

  • ServingEndpointConfig — Queries a specific served model in a multi-version serving endpoint via REST API.

setup_udtf#

ml_toolkit.functions.llm.udtf.interface.setup_udtf(config: ml_toolkit.functions.llm.udtf.config.UDTFConfig, name: str = None, prompt_column: str = 'prompt') → type[source]#

Register a PySpark UDTF based on the config type.

Dispatches to the appropriate UDTF implementation based on the config class. After registration, the UDTF can be called from Spark SQL.

Supported config types:

  • AgentUDTFConfig — Runs an OpenAI Agents SDK agent per row.

  • MultiModelUDTFConfig — Sends each prompt to multiple models in parallel, yielding one row per (input, model).

Parameters:
  • config – A UDTFConfig subclass instance.

  • name – UDTF name for SQL registration. Defaults vary by type: run_{config.name} for agents, "multi_model_inference" for multi-model.

  • prompt_column – Column name containing the user prompt. Defaults to "prompt".

Returns:

The UDTF class (also registered in SparkSession).

Raises:

TypeError – If config is not a supported subclass.

Examples

from ml_toolkit.functions.llm.udtf import AgentUDTFConfig, setup_udtf

config = AgentUDTFConfig(
    name="my_agent",
    instructions="You are a helpful assistant.",
    model="openai/databricks-claude-3-7-sonnet",
)
setup_udtf(config)
spark.sql("SELECT * FROM run_my_agent(TABLE(SELECT * FROM prompts))")
from ml_toolkit.functions.llm.udtf import MultiModelUDTFConfig, setup_udtf

config = MultiModelUDTFConfig(max_workers=20)
setup_udtf(config)
spark.sql("""
    SELECT * FROM multi_model_inference(
        TABLE(SELECT prompt, 'gpt-4o,claude-3' AS models FROM prompts)
    )
""")

Agent UDTF#

Register an agent as a PySpark UDTF, then call it from SQL or use the high-level batch API.

Agent UDTF quick start#
from ml_toolkit.functions.llm.udtf import AgentUDTFConfig, setup_udtf, run_agent_batch

config = AgentUDTFConfig(
    name="my_agent",
    instructions="You are a helpful assistant.",
    model="openai/databricks-claude-3-7-sonnet",
    tools=[my_tool_function],
)

setup_udtf(config)
results = spark.sql("SELECT * FROM run_my_agent(TABLE(SELECT * FROM prompts))")

# Or use the high-level batch API with table writes:
result = run_agent_batch(
    data_source="catalog.schema.prompts",
    agent_config=config,
    cost_component_name="my_team",
)

AgentUDTFConfig#

class ml_toolkit.functions.llm.udtf.config.AgentUDTFConfig[source]#

Configuration for an OpenAI Agents SDK agent UDTF.

Extends UDTFConfig with agent-specific parameters: model, instructions, tools, and LiteLLM routing.

Parameters:
  • name – Unique name for this agent (used in UDTF registration and telemetry).

  • instructions – System instructions for the agent. Supports Jinja2 templates with variables from the input row (e.g., {{database}}).

  • model – LiteLLM model identifier (e.g., "openai/gpt-5").

  • tools – List of Python callables to register as agent tools. Each function should have type annotations and a docstring (used to auto-generate the tool schema).

  • include_sql_tool – Whether to include the built-in read-only SQL execution tool.

  • include_web_search_tool – Whether to include the built-in web search tool.

  • max_turns – Maximum number of agent turns (tool call rounds).

  • warehouse_id – Databricks SQL warehouse ID for the SQL tool. Resolved from DATABRICKS_WAREHOUSE_ID env var or Databricks secrets if None.

  • litellm_base_url – LiteLLM proxy base URL. Resolved from Databricks secrets if None.

  • litellm_api_key – LiteLLM proxy API key. Resolved from Databricks secrets if None.

  • extra_agent_kwargs – Additional kwargs passed to the agents.Agent constructor.

Examples

Basic configuration with custom tools.#
from ml_toolkit.functions.llm.udtf import AgentUDTFConfig

def my_lookup(query: str) -> str:
    """Look up data in a custom API."""
    return "result"

config = AgentUDTFConfig(
    name="my_agent",
    instructions="You are a research assistant.",
    model="openai/gpt-5",
    tools=[my_lookup],  # auto-wrapped with agents.function_tool()
)
Configuration with built-in tools and Jinja2 instructions.#
config = AgentUDTFConfig(
    name="corp_insights",
    instructions="You analyze data for the {{database}} team.",
    model="openai/databricks-claude-3-7-sonnet",
    include_sql_tool=True,
    include_web_search_tool=True,
)

Note

Tools are plain Python callables with type annotations and docstrings. They are automatically wrapped with agents.function_tool() at agent build time. The OpenAI Agents SDK uses the function signature and docstring to generate the tool schema that the LLM sees.

run_agent_batch#

ml_toolkit.functions.llm.udtf.interface.run_agent_batch(data_source: pyspark.sql.DataFrame | str, agent_config: ml_toolkit.functions.llm.udtf.config.AgentUDTFConfig, prompt_column: str = 'prompt', output_table_name: str | None = None, dry_run: bool = None, table_operation: Literal['overwrite', 'append'] = 'overwrite', cost_component_name: str = None) dict[source]#

Run an agent against every row in a DataFrame or table.

This is the high-level batch API (analogous to run_llm_batch()). It registers a temporary UDTF, runs it against all input rows, and writes results to a Delta table or returns them as a DataFrame.

There are two operation modes:

  1. dry_run=True: Returns results as a DataFrame without writing. Only works with fewer than 1,000 rows.

  2. dry_run=False: Writes results to output_table_name. Can overwrite or append (set via table_operation).

Attention

You must pass a cost_component_name, otherwise this function will raise an exception.

Parameters:
  • data_source – Input DataFrame or fully-qualified Delta table name.

  • agent_configAgentUDTFConfig defining the agent.

  • prompt_column – Column name containing the user prompt. Defaults to "prompt".

  • output_table_name – Fully-qualified output Delta table name (required for batch mode).

  • dry_run – If True, returns results as DataFrame. If None, auto-detected based on row count.

  • table_operation"overwrite" or "append" for the output table.

  • cost_component_name – Required cost component for telemetry.

Returns:

Dict with source_uuid, model, df, output_table_name, cost_component_name, and error_message.

Raises:
  • ValueError – If cost_component_name is missing, or if output_table_name is missing in batch mode.

  • MLOpsToolkitTooManyRowsForInteractiveUsage – If dry_run=True with too many rows.

Examples

from ml_toolkit.functions.llm.udtf import AgentUDTFConfig, run_agent_batch

config = AgentUDTFConfig(
    name="corp_insights",
    instructions="You analyze corporate data for the {{database}} team.",
    model="openai/databricks-claude-3-7-sonnet",
    include_sql_tool=True,
    include_web_search_tool=True,
)

result = run_agent_batch(
    data_source="catalog.schema.input_prompts",
    agent_config=config,
    prompt_column="prompt",
    output_table_name="catalog.schema.agent_output",
    cost_component_name="my_team",
)
display(result["df"])

Multi-Model Inference UDTF#

Send each input row to multiple models in parallel and get one output row per (input, model). Useful for model comparison, consensus checks, and A/B testing.

Multi-model UDTF quick start#
from ml_toolkit.functions.llm.udtf import MultiModelUDTFConfig, setup_udtf

config = MultiModelUDTFConfig(max_workers=20, max_retries=3)
setup_udtf(config)

# Input must have: prompt, models (comma-separated), optional system_prompt
results = spark.sql("""
    SELECT * FROM multi_model_inference(
        TABLE(
            SELECT
                prompt,
                'openai/gpt-4o,openai/databricks-claude-3-7-sonnet' AS models
            FROM my_prompts
        )
    )
""")

The UDTF yields one row per (input, model) with this schema:

| Column     | Type   | Description                                               |
|------------|--------|-----------------------------------------------------------|
| model      | string | The model that produced this output                       |
| output     | string | Cleaned response (markdown code fences stripped)          |
| parameters | string | JSON of all non-prompt/models columns from the input row  |
| error      | string | Error details if the call failed (empty on success)       |
| raw_output | string | Unprocessed model response                                |

MultiModelUDTFConfig#

class ml_toolkit.functions.llm.udtf.config.MultiModelUDTFConfig[source]#

Configuration for a multi-model inference UDTF.

Extends UDTFConfig with LiteLLM credentials and optional structured-output schema. The UDTF sends each input row to multiple models in parallel and yields one output row per (input, model).

Parameters:
  • litellm_base_url – LiteLLM proxy base URL. Resolved from LITELLM_PROXY_BASE_URL env var or Databricks secrets if None.

  • litellm_api_key – LiteLLM proxy API key. Resolved from LITELLM_API_KEY env var or Databricks secrets if None.

  • response_schema – Optional JSON-schema dict for structured output. When set, passed as response_format to the chat completions API.

Examples

from ml_toolkit.functions.llm.udtf import (
    MultiModelUDTFConfig,
    setup_multi_model_udtf,
)

config = MultiModelUDTFConfig(max_workers=20, max_retries=3)
setup_multi_model_udtf(config)

# Input must have: prompt, models (comma-separated), optional system_prompt
spark.sql("""
    SELECT * FROM multi_model_inference(
        TABLE(SELECT prompt, 'gpt-4o,claude-3' AS models FROM my_table)
    )
""")

Inference UDTF#

Run a single-model chat completion against every row. Supports <<column>> prompt templates, system prompts, and any LiteLLM-compatible model.

Inference UDTF quick start#
from ml_toolkit.functions.llm.udtf import InferenceConfig, setup_udtf

config = InferenceConfig(
    model="openai/gpt-5",
    prompt_template="Tag the following vendor name: <<input>>. Options: <<candidates>>.",
    system_prompt="You are a precise tagging assistant.",
    input_column="input",
    additional_context_columns=["candidates"],
    max_output_tokens=256,
    cost_component_name="my_team",
)

setup_udtf(config, name="vendor_tagger")
results = spark.sql("""
    SELECT * FROM vendor_tagger(
        TABLE(SELECT * FROM catalog.schema.eval_data)
    )
""")

InferenceConfig#

class ml_toolkit.functions.llm.udtf.config.InferenceConfig[source]#

Configuration for inference UDTF.

Extends UDTFConfig with model, prompt, and LiteLLM credentials for running chat completion inference.

Parameters:
  • model – LiteLLM model identifier (e.g., "gpt-5").

  • prompt_template – Jinja2 template using <<variable>> syntax. If None, the raw input column value is used as the prompt.

  • system_prompt – Optional system prompt for the chat completion.

  • max_output_tokens – Maximum output tokens for model response.

  • input_column – Column name containing the input text.

  • additional_context_columns – Additional columns available in template context.

  • litellm_base_url – LiteLLM proxy base URL. Resolved from Databricks secrets if None.

  • litellm_api_key – LiteLLM proxy API key. Resolved from Databricks secrets if None.

  • litellm_model_kwargs – Additional kwargs for the chat completions API.

Serving Endpoint UDTF#

Query a specific served model version in a multi-version serving endpoint using the Databricks REST API. This is the low-level UDTF behind query_serving_endpoint() (see Serving Endpoints).

The endpoint name and served model can be set in the config or provided per-row, enabling dynamic routing across endpoints and versions.

Serving endpoint UDTF quick start#
from ml_toolkit.functions.llm.udtf import ServingEndpointConfig, setup_serving_endpoint_udtf

config = ServingEndpointConfig(
    endpoint_name="my-endpoint",
    served_model_name="v1",
)
setup_serving_endpoint_udtf(config, name="query_my_endpoint")

results = spark.sql("""
    SELECT * FROM query_my_endpoint(
        TABLE(SELECT prompt FROM catalog.schema.prompts)
    )
""")
Per-row dynamic routing#
# DataFrame provides endpoint_name and served_model_name per row
config = ServingEndpointConfig()  # no defaults needed
setup_serving_endpoint_udtf(config)

results = spark.sql("""
    SELECT * FROM udtf_serving_endpoint(
        TABLE(
            SELECT
                prompt,
                endpoint_name,
                served_model_name,
                512 AS max_tokens,
                0.7 AS temperature
            FROM catalog.schema.routed_prompts
        )
    )
""")

ServingEndpointConfig#

class ml_toolkit.functions.llm.udtf.config.ServingEndpointConfig[source]#

Configuration for serving endpoint UDTF.

Extends UDTFConfig for querying specific served models in multi-version serving endpoints using the Databricks REST API.

Parameters:
  • endpoint_name – Name of the serving endpoint (optional, can be provided per-row).

  • served_model_name – Served model name to route to (optional, can be provided per-row). Format: "{name}-{endpoint_name}" (e.g., "v1-my-endpoint").

  • timeout_sec – Timeout in seconds for serving endpoint requests. Default is 900 seconds (15 minutes) to accommodate cold starts from scale-to-zero.

Examples

Configuration with default endpoint routing.#
from ml_toolkit.functions.llm.udtf import ServingEndpointConfig

config = ServingEndpointConfig(
    endpoint_name="my-endpoint",
    served_model_name="v1-my-endpoint",
)

Note

Both endpoint_name and served_model_name can be provided per-row in the DataFrame, allowing dynamic routing across multiple endpoints and model versions.

setup_serving_endpoint_udtf#

ml_toolkit.functions.llm.udtf.serving_endpoint.setup_serving_endpoint_udtf(config: ml_toolkit.functions.llm.udtf.config.ServingEndpointConfig, name: str = 'udtf_serving_endpoint', prompt_column: str = 'prompt')[source]#

Setup and register the serving endpoint UDTF.

Uses the generic setup_udtf from base.py for consistent behavior.

Parameters:
  • config – ServingEndpointConfig instance

  • name – UDTF name for SQL registration

  • prompt_column – Column name containing the user prompt

Returns:

The registered UDTF class

Fine-Tuning#

The run_fine_tuning function provides end-to-end LLM fine-tuning using QLoRA with Optuna hyperparameter optimization. It loads training data from a Unity Catalog table, trains a QLoRA adapter, registers the merged model to Unity Catalog, and returns training metrics. Supports both local GPU execution and remote Databricks serverless/classic GPU clusters.

Quick start — minimal fine-tuning with defaults#
from ml_toolkit.functions.llm import run_fine_tuning

REQUIREMENTS = [
    "transformers==4.57.6", "datasets==4.5.0", "accelerate==1.12.0",
    "peft==0.18.1", "bitsandbytes==0.49.1", "safetensors==0.7.0",
    "threadpoolctl==3.6.0", "optuna==4.7.0",
]

result = run_fine_tuning(
    training_table_name="catalog.schema.training_data",
    model_name="catalog.schema.my_fine_tuned_model",
    base_model="qwen2_5_3b_instruct",
    requirements=REQUIREMENTS,
)
print(f"Training loss: {result['training_loss']:.4f}")
print(f"Model version: {result['model_version']}")
Fully customized — all config options#
from ml_toolkit.functions.llm import run_fine_tuning
from ml_toolkit.ml.llm.fine_tuning.config import (
    DataConfig, LoRAConfig, TrainingConfig,
)

result = run_fine_tuning(
    training_table_name="catalog.schema.vendor_data",
    model_name="catalog.schema.vendor_tagger",
    base_model="qwen2_5_7b_instruct",
    requirements=REQUIREMENTS + ["jinja2"],
    adapter_config=LoRAConfig(
        r=32,                        # LoRA rank (default: 16)
        alpha=64,                    # scaling factor (default: 2*r)
        dropout=0.1,                 # dropout rate (default: 0.05)
        target_modules="all-linear", # (default: "all-linear")
    ),
    training_config=TrainingConfig(
        max_steps=300,               # steps per trial (default: 100)
        n_trials=10,                 # Optuna trials (default: 3)
        max_seq_length=1024,         # max tokens (default: 256)
        seed=42,                     # (default: 9)
        eval_steps=10,               # (default: 1)
        logging_steps=5,             # (default: 1)
        compute_dtype="bf16",        # (default: "auto")
        tier="large",                # override auto-detected tier
    ),
    data_config=DataConfig(
        text_column="description",   # (default: "text")
        label_column="merchant",     # (default: "label")
        id_column="row_uuid",        # (default: "id")
        max_train_samples=5000,      # (default: 100)
        train_test_val_split=(0.7, 0.15, 0.15),  # (default: (0.6, 0.2, 0.2))
        data_source_filter="bank_of_america",     # (default: None)
        data_source_column="data_source",         # (default: None)
        prompt_registry_name="catalog.schema.my_prompt",
        prompt_alias="production",
    ),
    env_vars={"TOKENIZERS_PARALLELISM": "false"},
    # --- Remote execution ---
    trigger_remote=True,             # submit as remote job (default: False)
    cluster_type="serverless_gpu",   # "serverless_gpu" or "classic_gpu" (default: "serverless_gpu")
    gpu_type="h100",                 # "a10" or "h100" (default: "a10")
    num_gpus=1,                      # number of GPUs (default: 1)
    wait=True,                       # False returns RemoteJobRun handle (default: True)
)
Remote execution on serverless GPU#
result = run_fine_tuning(
    training_table_name="catalog.schema.training_data",
    model_name="catalog.schema.my_model",
    base_model="qwen2_5_3b_instruct",
    requirements=REQUIREMENTS,
    trigger_remote=True,             # submit as remote job
    cluster_type="serverless_gpu",   # (default: "serverless_gpu")
    gpu_type="a10",                  # "a10" or "h100" (default: "a10")
    num_gpus=1,                      # (default: 1)
    wait=True,                       # False returns RemoteJobRun handle
)

Model signature — by default no MLflow signature is set on the registered model. You can provide custom input/output schemas as PySpark types or DDL strings. This controls how Databricks Model Serving validates request/response payloads.

Custom signature with PySpark types#
from pyspark.sql.types import StructType, StructField, StringType, ArrayType

result = run_fine_tuning(
    training_table_name="catalog.schema.training_data",
    model_name="catalog.schema.my_ner_model",
    base_model="qwen2_5_3b_instruct",
    requirements=REQUIREMENTS,
    input_schema=StructType([
        StructField("text", StringType(), nullable=False),
        StructField("context", StringType(), nullable=True),
    ]),
    output_schema=ArrayType(StringType()),
)
Custom signature with DDL strings#
result = run_fine_tuning(
    training_table_name="catalog.schema.training_data",
    model_name="catalog.schema.my_model",
    base_model="qwen2_5_3b_instruct",
    requirements=REQUIREMENTS,
    input_schema="struct<prompt:string>",
    output_schema="array<string>",
)

Available base models — pass any key as base_model, or use a direct HuggingFace model ID (set tier in TrainingConfig when using a custom HF ID):

| Key                    | HuggingFace ID                            | Tier   |
|------------------------|-------------------------------------------|--------|
| qwen3_30b_a3b          | Qwen/Qwen3-30B-A3B                        | xlarge |
| qwen3_14b              | Qwen/Qwen3-14B                            | large  |
| qwen3_8b               | Qwen/Qwen3-8B                             | large  |
| qwen3_4b               | Qwen/Qwen3-4B                             | medium |
| qwen3_4b_instruct_2507 | Qwen/Qwen3-4B-Instruct-2507               | medium |
| qwen2_5_7b_instruct    | Qwen/Qwen2.5-7B-Instruct                  | large  |
| qwen2_5_3b_instruct    | Qwen/Qwen2.5-3B-Instruct                  | medium |
| qwen2_5_1_5b_instruct  | Qwen/Qwen2.5-1.5B-Instruct                | small  |
| qwen2_5_0_5b_instruct  | Qwen/Qwen2.5-0.5B-Instruct                | xsmall |
| smollm2_360m_instruct  | HuggingFaceTB/SmolLM2-360M-Instruct       | xsmall |
| smollm2_135m_instruct  | HuggingFaceTB/SmolLM2-135M-Instruct       | xsmall |
| tinyllama_1_1b_chat    | TinyLlama/TinyLlama-1.1B-Chat-v1.0        | small  |
| deepseek_distill_2b    | deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B | small  |

Using any HuggingFace model — you are not limited to the catalog above. Pass any HuggingFace model ID directly as base_model (e.g., "meta-llama/Llama-3.1-8B-Instruct"). When using a custom model ID, set tier in TrainingConfig to control batch sizes and gradient accumulation (defaults to "medium" if not set):

result = run_fine_tuning(
    training_table_name="catalog.schema.training_data",
    model_name="catalog.schema.my_model",
    base_model="meta-llama/Llama-3.1-8B-Instruct",  # any HuggingFace model ID
    requirements=REQUIREMENTS,
    training_config=TrainingConfig(tier="large"),     # set tier for resource allocation
)

Tier controls batch sizes and gradient accumulation in the Optuna HP search space:

| Tier   | Batch sizes   | Gradient accumulation |
|--------|---------------|-----------------------|
| xsmall | [32, 64, 128] | [1, 2, 4]             |
| small  | [16, 32, 64]  | [1, 2, 4, 8]          |
| medium | [4, 8, 16]    | [2, 4, 8]             |
| large  | [1, 2, 4]     | [8, 16, 32]           |
| xlarge | [1, 2]        | [16, 32, 64]          |

Prompt template options — three ways to provide a prompt:

  1. str.format (default) — uses {text} placeholder, only the text_column value is available:

    DataConfig(prompt_template="Extract entities from: {text}\n\nAnswer:\n")
    
  2. Jinja2 — uses <<column>> delimiters, all DataFrame columns are available. Auto-detected when template contains <<, {%, or {#. Add "jinja2" to requirements.

    DataConfig(prompt_template="Transaction: <<text>>\nCandidates: <<candidates>>\nAnswer:\n")
    
  3. MLflow Prompt Registry — loaded at runtime, {{variable}} is auto-converted to <<variable>>. Cannot be combined with prompt_template.

    DataConfig(
        prompt_registry_name="catalog.schema.my_prompt",
        prompt_alias="production",   # or prompt_version=1
    )
    

run_fine_tuning#

LoRAConfig#

TrainingConfig#

DataConfig#

Functions#

run_llm_batch#

estimate_token_usage#

query#

LLMResponse#

class ml_toolkit.functions.llm.query.LLMResponse[source]#

Class that defines the LLM response of the query function. It exposes the following attributes:

  • .text: returns the text response

  • .response: returns the raw LLM response object (openai.ChatCompletion)

It also exposes the following method:

  • .json(): attempts to parse the output into a Python dict
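
A minimal usage sketch: the query call below is illustrative (its full parameter list is documented under query above), while the LLMResponse attributes and .json() method are used as described.

from ml_toolkit.functions.llm import query

# Hypothetical arguments — see the query entry above for the real signature.
response = query(
    "Return a JSON object with keys 'merchant' and 'category' for: STARBUCKS #1234",
    cost_component_name="my_team",
)

print(response.text)      # plain-text completion
parsed = response.json()  # dict, if the output parses as JSON
raw = response.response   # underlying openai.ChatCompletion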

__init__(raw_response: openai.ChatCompletion)[source]#

GPU Inference#

The run_inference function runs vLLM batch inference on Databricks Serverless GPU clusters using self-hosted Unity Catalog models. It wraps run_serverless_inference and submits the work as a remote job via run_remote(), so you can launch inference from any notebook — no GPU cluster required on the caller side.

The function validates inputs (table names, GPU config), serializes parameters, and submits a one-time Databricks workflow that provisions GPUs, runs vLLM inference via Ray Data, and writes results to a Delta table.

Fire-and-forget — submit and get a tracking URL#
from ml_toolkit.functions.llm import run_inference

remote_run = run_inference(
    "catalog.schema.input_table",
    "catalog.schema.my_model",
    "catalog.schema.output",
    uc_volumes_path="/Volumes/catalog/schema/vol",
)

# Track in the Databricks UI
print(remote_run.databricks_url)
print(remote_run.job_run_id)

# When ready, block until the job finishes
result = remote_run.get_result()
Block until inference completes#
result = run_inference(
    "catalog.schema.input_table",
    "catalog.schema.my_model",
    "catalog.schema.output",
    uc_volumes_path="/Volumes/catalog/schema/vol",
    wait=True,
)
Run on the current GPU cluster (no remote submission)#
run_inference(
    "catalog.schema.input_table",
    "catalog.schema.my_model",
    "catalog.schema.output",
    uc_volumes_path="/Volumes/catalog/schema/vol",
    trigger_remote=False,
)
H100 with custom generation parameters#
remote_run = run_inference(
    "catalog.schema.input_table",
    "catalog.schema.my_model",
    "catalog.schema.output",
    uc_volumes_path="/Volumes/catalog/schema/vol",
    num_gpus=8,
    gpu_type="h100",
    temperature=0.0,
    max_new_tokens=1024,
    tensor_parallel_size=8,
)

GPU types and constraints:

| GPU type | num_gpus         | Notes                                                                  |
|----------|------------------|------------------------------------------------------------------------|
| a10      | Any (default: 4) | Good for models up to ~8B parameters. Default tensor_parallel_size=1.  |
| h100     | Multiple of 8    | Required for large models (>8B). num_gpus must be a multiple of 8.     |

Auto-loading behaviour: When the model was registered with model_config and prompt_uris via register_LLM_model, generation parameters (temperature, max_new_tokens, max_model_len, etc.) and prompt templates are loaded automatically. Explicit parameters always take precedence.
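
A minimal sketch of that flow, assuming the same model and tokenizer objects as in the Model Registration quick start further below; explicit generation parameters passed to run_inference still take precedence over the stored model_config.

from ml_toolkit import register_LLM_model
from ml_toolkit.functions.llm import run_inference

# Register with generation defaults baked into the model version.
register_LLM_model(
    model_name="catalog.schema.my_model",
    model=model,          # assumed to be defined, as in the registration quick start
    tokenizer=tokenizer,
    model_config={"temperature": 0.0, "max_new_tokens": 512},
)

# Later: generation parameters are auto-loaded from the registered model_config.
remote_run = run_inference(
    "catalog.schema.input_table",
    "catalog.schema.my_model",
    "catalog.schema.output",
    uc_volumes_path="/Volumes/catalog/schema/vol",
)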

run_inference#

ml_toolkit.functions.llm.inference.interface.run_inference(input_source: str | pyspark.sql.DataFrame, model_name: str, output_table: str, *, prompt_column: str = 'prompt', model_version: int | Literal['latest'] | None = None, model_alias: str | None = None, system_prompt: str | None = None, max_new_tokens: int | None = None, temperature: float | None = None, top_p: float | None = None, top_k: int | None = None, min_p: float | None = None, presence_penalty: float | None = 0.0, repetition_penalty: float | None = 1.0, max_model_len: int | None = None, num_gpus: int = 4, gpu_type: str = 'a10', concurrency: int | None = None, batch_size: int | None = None, tensor_parallel_size: int = 1, gpu_memory_utilization: float | None = 0.85, max_num_seqs: int | None = None, max_num_batched_tokens: int | None = None, kv_cache_dtype: str | None = None, dtype: str | None = None, quantization: str | None = None, uc_volumes_path: str | None = None, auto_build_prompt: bool = True, output_column: str = 'model_output', pk_column: str | None = None, table_operation: Literal['overwrite', 'append'] | None = 'overwrite', ray_object_store_memory: float = 0.5, http_pool_maxsize: int = 64, trigger_remote: bool = True, wait: bool = False, compute_type: Literal['serverless', 'classic', 'anyscale'] = 'serverless', capacity_reservation_id: str | None = 'auto', availability_zone: str | None = 'auto', ray_kwargs: dict | None = None, env_vars: dict[str, str] | None = None, output_schema: object | None = None, chat_template_kwargs: dict | None = None, output_storage_location: str | None = None, output_bucket: str | None = None, checkpoint_path: str | None = None, anyscale_kwargs: AnyscaleKwargs | dict | None = None) ml_toolkit.functions.llm.inference.function.RemoteInferenceRun | dict | None[source]#

Run vLLM batch inference on Databricks GPU compute.

Supports two compute backends via compute_type:

  • "serverless" (default) — Databricks Serverless GPU. Uses A10 or H100 accelerators. Requires uc_volumes_path for data staging.

  • "classic" — Databricks classic cluster with p5.48xlarge (8x H100) or p6-b200.48xlarge (8x B200) nodes. Uses Ray-on-Spark for distributed inference. GPUs in multiples of 8. The driver contributes 8 GPUs; the number of additional worker nodes is auto-computed as max(0, (num_gpus - 8) // 8). Supports AWS Capacity Blocks.

When trigger_remote=True (default), the function validates inputs, serializes parameters, and submits a remote job via run_remote().

When trigger_remote=False, inference runs directly on the current cluster (must already have GPUs available).

Parameters:
  • input_source – Fully qualified UC table name (catalog.schema.table) or DataFrame.

  • model_name – UC model name (catalog.schema.model).

  • output_table – Fully qualified output table name.

  • prompt_column – Column containing prompt text. Defaults to "prompt".

  • model_version – Specific model version (int) or "latest". Mutually exclusive with model_alias. None = production alias.

  • model_alias – UC alias (e.g. "champion"). Mutually exclusive with model_version.

  • system_prompt – System instruction for chat completions.

  • max_new_tokens – Maximum tokens to generate per prompt.

  • temperature – Sampling temperature (0.0 = greedy).

  • top_p – Nucleus sampling threshold.

  • top_k – Controls the number of top tokens to consider. Set to 0 (or -1) to consider all tokens.

  • min_p – Represents the minimum probability for a token to be considered, relative to the probability of the most likely token. Must be in [0, 1].

  • presence_penalty – Penalizes new tokens based on whether they appear in the generated text so far. Values > 0 encourage the model to use new tokens, while values < 0 encourage the model to repeat tokens.

  • repetition_penalty – Penalizes new tokens based on whether they appear in the prompt and the generated text so far. Values > 1 encourage the model to use new tokens, while values < 1 encourage the model to repeat tokens.

  • max_model_len – Maximum sequence length.

  • num_gpus – Number of GPUs. Defaults to 4. For compute_type="classic", must be a multiple of 8 (H100 only) and counts the total across driver + workers. num_gpus=8 runs driver-only; num_gpus=16 runs driver + 1 worker node.

  • gpu_type – GPU type — "a10", "h100", or "b200". Defaults to "a10". For compute_type="classic", must be "h100" or "b200".

  • concurrency – Number of vLLM data-parallel replicas.

  • batch_size – Rows per inference batch.

  • tensor_parallel_size – GPUs per vLLM instance.

  • gpu_memory_utilization – Fraction of GPU memory to use for vLLM.

  • max_num_seqs – Maximum concurrent sequences per vLLM engine.

  • max_num_batched_tokens – Maximum tokens per vLLM iteration step.

  • kv_cache_dtype – Data type for KV cache (e.g. "fp8").

  • dtype – Model weight data type for vLLM (e.g. "bfloat16", "float16", "auto"). None = auto-resolved from model config.

  • quantization – vLLM quantization method. "auto" (default) = resolved from model_config → GPU defaults ("fp8" for H100/B200, None for A10). "fp8" = dynamically quantize weights to FP8_E4M3 at load time. None = don’t pass quantization to vLLM (auto-detect from model).

  • uc_volumes_path – UC Volume path for staging data. Required.

  • auto_build_prompt – Auto-build prompt column from model template.

  • output_column – Name of the output column containing model predictions. Defaults to "model_output".

  • pk_column – Name of an existing column to use as the row identifier for joining inference results back to the input. If None (default), a synthetic row ID is generated automatically. Use this when your input already has a unique primary key column to avoid creating a redundant one.

  • table_operation – How to write the output table — "overwrite" (default) replaces the table, "append" adds rows to an existing table.

  • ray_object_store_memory – Fraction of cluster memory for Ray object store. Defaults to 0.5.

  • http_pool_maxsize – Max concurrent HTTP connections for model download.

  • trigger_remote – If True, submit via run_remote(). If False, run directly on current cluster.

  • wait – If True and trigger_remote=True, block until completion and return the result dict.

  • compute_type"serverless" (default) or "classic".

  • capacity_reservation_id – AWS Capacity Block reservation ID (e.g. "cr-0abc1234def56789a"). Classic only. Configures the cluster with on-demand instances and the X-Databricks-AwsCapacityBlockId tag. Pass "auto" to fetch from the WORKSPACE_CONFIGURATION secret scope (keys GPU_CAPACITY_RESERVATION_ID and GPU_AVAILABILITY_ZONE).

  • availability_zone – AWS availability zone for the Capacity Block (e.g. "us-east-1a"). Required when using capacity_reservation_id (auto-resolved when "auto").

  • ray_kwargs – Dict of Ray-on-Spark tuning parameters forwarded to run_classic_inference (classic only). Supported keys: num_cpus_per_node, num_cpus_head_node, num_gpus_head_node.

  • env_vars – Optional dictionary of environment variables to set in each Ray worker process before inference starts. Ray actors do not inherit the driver’s environment, so diagnostic variables like VLLM_LOGGING_LEVEL=DEBUG must be forwarded explicitly. Defaults to None (no extra variables).

  • output_schema – Optional PySpark DataType describing the JSON shape the model is expected to produce (e.g. ArrayType(StructType([...]))). When set, the output table gets a parsed model_output struct column alongside the raw text in model_output_raw. Accepts a DataType instance, a DDL string (e.g. "array<struct<l1_category:string,...>>"), or the dict form produced by ml_toolkit.ops.helpers.mlflow.serialize_pyspark_schema(). Required for HuggingFace model IDs (no MLflow signature); for UC models, overrides whatever the MLflow signature declared.

  • chat_template_kwargs – Extra keyword arguments forwarded to the tokenizer’s apply_chat_template call when prompts are rendered. Use this to toggle model-family-specific flags such as {"enable_thinking": False} for Qwen3 (suppresses <think> blocks) or {"thinking": True} for DeepSeek-V3.1 / Granite 3.2. The toolkit does not interpret the dict — it splats the keys into the call, so callers are responsible for using the kwarg name their model expects. Defaults to None (preserves prior behaviour). Applied on both serverless and classic compute paths.

Returns:

  • RemoteInferenceRun when trigger_remote=True and wait=False.

  • Result dict when trigger_remote=True and wait=True.

  • None when trigger_remote=False.

Raises:

ValueError – If validation fails (table format, GPU config, etc.).

RemoteInferenceRun#

class ml_toolkit.functions.llm.inference.function.RemoteInferenceRun[source]#

Handle for a submitted remote inference job (serverless or classic).

get_result(polling_interval: int = 60) → dict[str, Any][source]#

Block until the job completes and return the result dict.

is_complete() → bool[source]#

Non-blocking check whether the job has finished.

property status: str[source]#

Current lifecycle/result state of the job.
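
A small usage sketch tying these together, reusing the run_inference call from the examples above:

import time

from ml_toolkit.functions.llm import run_inference

remote_run = run_inference(
    "catalog.schema.input_table",
    "catalog.schema.my_model",
    "catalog.schema.output",
    uc_volumes_path="/Volumes/catalog/schema/vol",
)

# Poll without blocking, then fetch the final result dict.
while not remote_run.is_complete():
    print(f"Job state: {remote_run.status}")
    time.sleep(60)

result = remote_run.get_result()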

Serving Endpoints#

The serving_endpoints submodule provides functions for deploying and managing Databricks serving endpoints for Unity Catalog LLM models. It supports single-version and multi-version deployments with automatic traffic management using predefined LLM serving standards.

LLM Serving Standards — five predefined tiers for endpoint configuration:

| Standard | Min tokens/sec | Max tokens/sec | Use case                      |
|----------|----------------|----------------|-------------------------------|
| XSMALL   | 0              | 1,000          | Dev / testing (scale-to-zero) |
| SMALL    | 1,000          | 5,000          | Light production (default)    |
| MEDIUM   | 5,000          | 10,000         | Standard production           |
| LARGE    | 10,000         | 30,000         | High throughput               |
| XLARGE   | 30,000         | 70,000         | Maximum throughput            |

Quick start — deploy, add a version, and shift traffic#
from ml_toolkit.functions.llm import (
    deploy_model_serving_endpoint,
    add_model_version_to_endpoint,
    update_endpoint_traffic,
    query_serving_endpoint,
)

# 1. Deploy a new endpoint (name inferred as "tiny_llama")
endpoint = deploy_model_serving_endpoint(
    model_name="catalog.schema.tiny_llama",
    model_versions=[1],
    standard="XSMALL",
    cost_component_name="ml_research",
    wait_for_ready=True,
)

# 2. Add a second version with equal traffic
endpoint = add_model_version_to_endpoint(
    version=2,
    model_name="catalog.schema.tiny_llama",
    standard="XSMALL",
)

# 3. Shift traffic to the new version
update_endpoint_traffic(
    model_name="catalog.schema.tiny_llama",
    traffic_config={1: 20, 2: 80},
)

# 4. Query a specific version
result_df = query_serving_endpoint(
    df=spark.table("prompts"),
    endpoint_name="tiny_llama",
    version=2,
    prompt_column="text",
)

deploy_model_serving_endpoint#

ml_toolkit.functions.llm.serving_endpoints.interface.deploy_model_serving_endpoint(model_name: str, model_versions: List[int], endpoint_name: str | None = None, endpoint_scaling_size: ml_toolkit.functions.llm.serving_endpoints.constants.EndpointScalingSize = 'SMALL', provisioning_type: ml_toolkit.functions.llm.serving_endpoints.constants.ProvisioningType | None = None, traffic_config: Dict[int, int] | None = None, cost_component_name: str | None = None, tags: Dict[str, str] | None = None, scale_to_zero_enabled: bool = True, wait_for_ready: bool = DEFAULT_WAIT_FOR_READY, timeout: int = DEFAULT_TIMEOUT) Dict[source]#

Deploy one or more versions of a Unity Catalog model to a NEW serving endpoint.

This function creates a new serving endpoint. If the endpoint already exists with any served models, it will raise MLOpsToolkitEndpointAlreadyExistsException. Use add_model_version_to_endpoint() to add versions to an existing endpoint.

Automatically determines whether to use provisioned throughput or GPU config by calling the Databricks get-model-optimization-info API. For provisioned throughput models, token values are rounded to the model’s chunk_size increment. For GPU models, the endpoint_scaling_size maps to a workload_type + workload_size combo.

Parameters:
  • model_name – Unity Catalog model path (catalog.schema.model). Required. Example: "yd_ml_dpe.fine_tune_staging.qlora_tinyllama_model".

  • model_versions – List of Unity Catalog model version numbers (integers) to deploy. Each entry corresponds to a version on the endpoint. Required. Example: [1] for a single version, [1, 2] for two versions. The served model names will be: v1, v2, etc.

  • endpoint_name – Optional name of the serving endpoint to create. If not provided, the endpoint name will be automatically inferred from the model name (the last part after the final dot). Must not already exist with served models, or MLOpsToolkitEndpointAlreadyExistsException will be raised. Example: if model_name is "catalog.schema.my_model", endpoint_name defaults to "my_model".

  • endpoint_scaling_size – Endpoint scaling size. Defaults to "SMALL". Must be one of: "XSMALL", "SMALL", "MEDIUM", "LARGE", "XLARGE". For provisioned throughput models, this maps to token-per-second ranges. For GPU models, this maps to workload_type + workload_size combos: XSMALL=GPU_SMALL/Small, SMALL=GPU_MEDIUM/Small, MEDIUM=GPU_MEDIUM/Medium, LARGE=MULTIGPU_MEDIUM/Medium, XLARGE=GPU_MEDIUM_8/Large.

  • provisioning_type – Optional override for provisioning type. Must be "PROVISIONED_THROUGHPUT" or "GPU". If not provided, the type is auto-detected by calling the Databricks optimization-info API.

  • traffic_config – Optional traffic distribution dict mapping version numbers to traffic percentages. If None, traffic is distributed equally. Example: {1: 70, 2: 30} means version 1 gets 70%, version 2 gets 30%.

  • cost_component_name – Required cost component for tracking. If not provided, will attempt to determine it from settings or environment, but will raise ValueError if not found. Always specify this explicitly for production deployments.

  • tags – Optional additional tags as key-value pairs.

  • scale_to_zero_enabled – Whether to enable scale-to-zero for the endpoint (default: True). When True, the endpoint scales to zero when idle; when False, it stays always warm.

  • wait_for_ready – Whether to wait for the endpoint to be ready before returning.

  • timeout – Maximum seconds to wait if wait_for_ready is True.

Return type:

Endpoint details dictionary

Examples

Deploy provisioned throughput model (auto-detected)#
from ml_toolkit.functions.llm.serving_endpoints import deploy_model_serving_endpoint

endpoint = deploy_model_serving_endpoint(
    model_name="catalog.schema.tiny_llama",
    model_versions=[1],
    endpoint_scaling_size="XSMALL",
    cost_component_name="ml_research",
    wait_for_ready=True,
)
Deploy GPU model with explicit provisioning type#
endpoint = deploy_model_serving_endpoint(
    model_name="catalog.schema.custom_model",
    model_versions=[1],
    endpoint_scaling_size="MEDIUM",
    provisioning_type="GPU",
    cost_component_name="data_integration_rd",
)
Deploy multiple versions with A/B testing#
endpoint = deploy_model_serving_endpoint(
    model_name="catalog.schema.llama_model",
    model_versions=[1, 2],
    traffic_config={1: 50, 2: 50},
    endpoint_scaling_size="LARGE",
    cost_component_name="ml_research",
)

add_model_version_to_endpoint#

ml_toolkit.functions.llm.serving_endpoints.interface.add_model_version_to_endpoint(version: int, model_name: str, endpoint_name: str | None = None, endpoint_scaling_size: ml_toolkit.functions.llm.serving_endpoints.constants.EndpointScalingSize = 'SMALL', provisioning_type: ml_toolkit.functions.llm.serving_endpoints.constants.ProvisioningType | None = None, traffic_percentage: int | None = None, redistribute_traffic: bool = True, scale_to_zero_enabled: bool = True, wait_for_ready: bool = DEFAULT_WAIT_FOR_READY, timeout: int = DEFAULT_TIMEOUT) Dict[source]#

Add a new model version to an existing serving endpoint.

This function fetches the current endpoint configuration, auto-detects the model name and version from existing versions, and adds the new version with automatic traffic distribution.

Automatically determines whether to use provisioned throughput or GPU config. Validates that the new entity’s provisioning type matches the existing endpoint entities.

Parameters:
  • version – Integer version identifier (e.g., 1, 2, 3). The served model name will be: v{version} (e.g., v1, v2, v3).

  • model_name – Unity Catalog model path (catalog.schema.model). Required. Used to infer the endpoint name if endpoint_name is not explicitly provided.

  • endpoint_name – Optional name of the existing serving endpoint. If not provided, the endpoint name will be automatically inferred from model_name.

  • endpoint_scaling_size – Endpoint scaling size. Defaults to "SMALL". Must be one of: "XSMALL", "SMALL", "MEDIUM", "LARGE", "XLARGE". For provisioned throughput models, this maps to token-per-second ranges. For GPU models, this maps to workload_type + workload_size combos.

  • provisioning_type – Optional override for provisioning type. Must be "PROVISIONED_THROUGHPUT" or "GPU". If not provided, the type is auto-detected by calling the Databricks optimization-info API.

  • traffic_percentage – Percentage of traffic to route to this new version (0-100). If None and redistribute_traffic=True, traffic is distributed equally. If specified, other versions' traffic will be proportionally reduced.

  • redistribute_traffic – If True, automatically adjust traffic across all versions. If False, the new version gets 0% traffic (you must manually call update_endpoint_traffic).

  • scale_to_zero_enabled – Whether to enable scale-to-zero for the new version (default: True). When True, the endpoint scales to zero when idle; when False, it stays always warm.

  • wait_for_ready – Whether to wait for the endpoint to be ready before returning.

  • timeout – Maximum seconds to wait if wait_for_ready is True.

Return type:

Updated endpoint details dictionary

Examples

Add version using model_name (recommended)#
from ml_toolkit.functions.llm.serving_endpoints import add_model_version_to_endpoint

endpoint = add_model_version_to_endpoint(
    version=3,
    model_name="catalog.schema.my_model",
    endpoint_scaling_size="MEDIUM",
)
Add version with explicit provisioning type#
endpoint = add_model_version_to_endpoint(
    version=2,
    endpoint_name="production_endpoint",
    endpoint_scaling_size="LARGE",
    provisioning_type="GPU",
    traffic_percentage=10,
)

update_endpoint_traffic#

ml_toolkit.functions.llm.serving_endpoints.interface.update_endpoint_traffic(traffic_config: Dict[int, int], model_name: str, endpoint_name: str | None = None, wait_for_ready: bool = DEFAULT_WAIT_FOR_READY, timeout: int = DEFAULT_TIMEOUT) Dict[source]#

Update the traffic distribution for an existing multi-version endpoint.

This is useful for gradual rollouts, canary deployments, or A/B testing scenarios where you want to adjust traffic percentages without redeploying models.

Parameters:
  • traffic_config – Dict mapping version numbers to traffic percentages. Must sum to 100. Example: {1: 30, 2: 70}.

  • model_name – Unity Catalog model path (catalog.schema.model). Required. Used to infer the endpoint name if endpoint_name is not explicitly provided.

  • endpoint_name – Optional name of the serving endpoint. If not provided, the endpoint name will be automatically inferred from model_name.

  • wait_for_ready – Whether to wait for the endpoint to be ready after the update.

  • timeout – Maximum seconds to wait if wait_for_ready is True.

Return type:

Updated endpoint details dictionary

Examples

Gradually increase traffic to new version using model_name#
# Start: 90% version 1, 10% version 2
update_endpoint_traffic(
    model_name="catalog.schema.my_model",
    traffic_config={1: 90, 2: 10},
)

# After monitoring: 50% split
update_endpoint_traffic(
    model_name="catalog.schema.my_model",
    traffic_config={1: 50, 2: 50},
)

# Final: 100% version 2
update_endpoint_traffic(
    model_name="catalog.schema.my_model",
    traffic_config={1: 0, 2: 100},
)
Using explicit endpoint name#
update_endpoint_traffic(
    endpoint_name="my-endpoint",
    traffic_config={1: 50, 2: 50},
)

get_serving_endpoint#

ml_toolkit.functions.llm.serving_endpoints.interface.get_serving_endpoint(model_name: str, endpoint_name: str | None = None) → Dict[source]#

Get details of a serving endpoint.

Parameters:
  • model_name – Unity Catalog model name (catalog.schema.model_name). Required; used to infer endpoint_name if not provided.

  • endpoint_name – Name of the endpoint. Optional; if not provided, automatically inferred from model_name.

Return type:

Endpoint details dictionary containing configuration, state, and metadata

Examples

from ml_toolkit.functions.llm.serving_endpoints import get_serving_endpoint

# Endpoint name automatically inferred from model name
endpoint = get_serving_endpoint(model_name="catalog.schema.tiny_llama")
# Endpoint name inferred as "tiny_llama"
print(f"State: {endpoint['state']['ready']}")
print(f"Served entities: {endpoint['config']['served_entities']}")

# Or specify explicit endpoint name
endpoint = get_serving_endpoint(
    model_name="catalog.schema.tiny_llama",
    endpoint_name="my-custom-endpoint"
)

list_serving_endpoints#

ml_toolkit.functions.llm.serving_endpoints.interface.list_serving_endpoints(filter_tags: Dict[str, str] | None = None) → List[Dict][source]#

List all serving endpoints in the workspace.

Parameters:
  • filter_tags – Optional dict of tags to filter by. Only endpoints matching ALL specified tags will be returned.

Return type:

List of endpoint dictionaries

Examples

from ml_toolkit.functions.llm.serving_endpoints import list_serving_endpoints

# List all endpoints
endpoints = list_serving_endpoints()
for ep in endpoints:
    print(f"{ep['name']}: {ep['state']['ready']}")

# Filter by cost component
my_team_endpoints = list_serving_endpoints(
    filter_tags={"cost_component_name": "ml_research"}
)

# Filter by multiple tags
prod_endpoints = list_serving_endpoints(
    filter_tags={"env": "production", "team": "ai-platform"}
)

query_serving_endpoint#

ml_toolkit.functions.llm.serving_endpoints.helpers.query.query_serving_endpoint(df: pyspark.sql.DataFrame, endpoint_name: str, version: int | None = None, prompt_column: str = 'prompt', prompt_template: str = None, max_tokens: int = 512, temperature: float = 0.7, output_column: str = 'output') pyspark.sql.DataFrame[source]#

Query a specific model version in a multi-version serving endpoint.

This convenience function prepares the DataFrame and calls the UDTF, automatically routing to the specified version (v1, v2, v3, etc.).

Parameters:
  • df – Input DataFrame with a prompt column

  • endpoint_name – Name of the serving endpoint

  • version – Integer version identifier (e.g., 1, 2, 3). Defaults to None (latest version). Routes to served model “v{version}” (e.g., v1, v2, v3). If None, automatically resolves to the highest available version.

  • prompt_column – Name of the column containing prompts (default: “prompt”). Ignored if prompt_template is specified.

  • prompt_template – Optional template string with <<column>> placeholders. Use double angle brackets to reference DataFrame columns. Examples: “Classify: <<text>>”, “Extract entity: <<merchant_name>>”

  • max_tokens – Maximum output tokens (default: 512)

  • temperature – Sampling temperature (default: 0.7)

  • output_column – Name of the output column (default: “output”)

Returns:

DataFrame with columns {output_column}, start_timestamp, end_timestamp, and error_message.

Example

>>> from ml_toolkit.functions.llm.serving_endpoints.helpers.query import query_serving_endpoint
>>> from pyspark.sql import SparkSession
>>>
>>> spark = SparkSession.builder.getOrCreate()
>>> df = spark.table("input_table")
>>>
>>> # Query latest version (default)
>>> result_df = query_serving_endpoint(
...     df=df,
...     endpoint_name="my_endpoint",
...     prompt_column="text",
...     max_tokens=1000,
...     temperature=0.5
... )
>>>
>>> # Query specific version
>>> result_df = query_serving_endpoint(
...     df=df,
...     endpoint_name="my_endpoint",
...     version=1,
...     prompt_template="Classify this <<category>> text: <<content>>",
...     max_tokens=1000,
...     temperature=0.5
... )

Model Registration#

The model registration functions provide a complete workflow for logging and registering models into Databricks Unity Catalog via MLflow. Two dedicated paths are available:

  • HuggingFace LLMs — use register_LLM_model to log a model + tokenizer via mlflow.transformers.

  • Generic pyfunc models — use register_model to log any custom PythonModel via mlflow.pyfunc.

Both functions support tagging, descriptions, artifact copying, prompt attachment, and optional auto-deployment to an alias on registration.

Quick start — register an LLM and deploy to prod#
from ml_toolkit import register_LLM_model

model_info = register_LLM_model(
    model_name="catalog.schema.llama_3_1_8b",
    model=model,
    tokenizer=tokenizer,
    torch_dtype="bfloat16",
    model_config={"max_new_tokens": 512, "temperature": 0.1},
    params={"learning_rate": 2e-4, "epochs": 3},
    metrics={"train_loss": 0.85},
    model_tags={"team": "nlp"},
    target_alias="prod",
)
print(f"Registered version {model_info.version}")
Quick start — register a generic pyfunc model#
from ml_toolkit import register_model

model_info = register_model(
    model_name="catalog.schema.my_classifier",
    custom_model=MyCustomModel(),
    source_artifacts_dir="/dbfs/tmp/model_weights",
    pip_requirements=["scikit-learn==1.4.0"],
    target_alias="prod",
)
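
The remaining registration helpers operate on the registered version. A minimal sketch — the argument names here (version, alias, local paths) are assumptions for illustration; see the per-function entries below for the exact signatures.

from ml_toolkit import deploy_model, log_artifacts, log_artifact

# Hypothetical arguments — consult deploy_model / log_artifacts / log_artifact below.
deploy_model(
    model_name="catalog.schema.my_classifier",
    version=model_info.version,
    alias="prod",
)

# Copy supporting files to the model version's UC Volume path.
log_artifacts("catalog.schema.my_classifier", "/dbfs/tmp/eval_reports")
log_artifact("catalog.schema.my_classifier", "/dbfs/tmp/label_map.json")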

register_LLM_model#

register_model#

deploy_model#

log_artifacts#

log_artifact#

Model Loading#

The model loading functions handle fetching registered models from Unity Catalog or HuggingFace Hub, with built-in retry logic for transient network or cluster errors.

Quick start — load a registered UC model with retry#
from ml_toolkit import load_model_with_retry, load_auto_tokenizer_with_retry, get_model_info

# Load by UC model name (latest version by default)
model, tokenizer = load_model_with_retry(
    model_checkpoint="catalog.schema.llama_3_1_8b",
    return_type="components",
)

# Inspect model version metadata
info = get_model_info("catalog.schema.llama_3_1_8b")
info.display()
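
load_auto_tokenizer_with_retry follows the same pattern; a sketch assuming it accepts the same checkpoint identifier as load_model_with_retry (see its entry below for the exact signature).

from ml_toolkit import load_auto_tokenizer_with_retry

# Assumed to mirror load_model_with_retry's checkpoint argument.
tokenizer = load_auto_tokenizer_with_retry("catalog.schema.llama_3_1_8b")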

get_model_info#

load_model_with_retry#

load_auto_tokenizer_with_retry#