llm module#

The llm module of the ml_toolkit contains all the functions that enable LLM usage within Databricks notebooks. We expose three functions:

  • run_llm_batch: performs row-level LLM querying over your data, writing each response to a new column.

  • estimate_token_usage: estimates token usage of the run_llm_batch function.

  • query: lets you query a model directly, supplying context and tools so it can perform an action.

Attention

These functions are in an experimental phase and are subject to change. If you have any feedback, please submit it via our Jira Form.

Attention

You must pass a cost_component_name to the functions that call the LLMs, otherwise they will raise exceptions.

Quota Controls#

We have set strict usage controls, or quotas. This means you can only consume a set number of tokens before an exception is raised. LLMs can get expensive quickly, so these limits are in place to avoid runaway costs.

If you have a valid use case that solves a business problem and needs a higher quota to run with your data, please submit a ticket. The approval process requires your manager to approve your use case and requested budget. When submitting, always include the output from estimate_token_usage.
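
For instance, a quick way to produce the estimate for your request is sketched below (the table name, prompt, and max_output_tokens are placeholders; adjust them to your use case):

from ml_toolkit.functions.llm import estimate_token_usage

# Hypothetical sketch: estimate usage before requesting a quota increase.
estimate = estimate_token_usage(
    data_source="catalog.schema.my_table",  # placeholder table
    prompt_source="Summarize the following ticket: <<ticket_text>>",  # placeholder prompt and column
    max_output_tokens=128,
)
print(estimate)  # attach this output to your quota request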

Available Models#

The following models are available:

  • databricks-meta-llama-3-1-8b-instruct

  • databricks-meta-llama-3-3-70b-instruct

  • databricks-meta-llama-3-1-405b-instruct

  • databricks-llama-4-maverick

  • databricks-claude-3-7-sonnet

  • gpt-4o

  • gpt-4o-mini

  • gpt-4o-search-preview (only in query)

  • gpt-4o-mini-search-preview (only in query)

Attention

We strongly suggest using the Llama models, as they are considerably cheaper than the OpenAI models.

Functions#

run_llm_batch#

ml_toolkit.functions.llm.run_llm_batch(data_source: pyspark.sql.DataFrame | str, prompt_source: str = None, output_column_name: str = 'llm_output', output_table_name: str | None = None, output_structured_schema: dict | None = None, model: str = DEFAULT_MODEL_PROCESS, dry_run: bool = None, max_output_tokens: int = DEFAULT_MAX_OUTPUT_TOKENS, wait_for_completion: bool = True, table_operation: Literal['overwrite', 'append'] = 'overwrite', primary_key_columns: List[str] | None = None, cost_component_name: str = None)[source]#

Performs a row-level operation on the input data (Spark DataFrame or Delta table): it passes the prompt_source (along with the row data it references) to an LLM model and writes the output to output_column_name.

There are two operation modes:

  1. dry_run=True: good for quick experimentation, POCs and prompt engineering; only works with less than 1k rows of data.

  2. dry_run=False: should be used for the production pipeline or when running over a lot of data. Writes data to output_table_name and can either overwrite or append (set via table_operation).

Caution

dry_run=True does incur usage and costs. It’s a mode designed to allow faster and cheaper experimentation. If all you want is to estimate usage on a bigger dataset, please run estimate_token_usage.
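
As a minimal sketch of the interactive mode (the table name, column, and prompt below are placeholders):

from ml_toolkit.functions.llm import run_llm_batch

# Hypothetical sketch: iterate on a prompt against a small sample in interactive mode.
small_df = spark.table("catalog.schema.some_table").limit(100)  # placeholder table

result = run_llm_batch(
    data_source=small_df,
    prompt_source="Classify the sentiment of: <<review_text>>",  # placeholder column
    dry_run=True,                # interactive mode; still incurs usage and cost
    max_output_tokens=32,
    cost_component_name=...,     # use your team's cost component here!
)
display(result["df"])  # the interactive result exposes the annotated DataFrame, as in the examples below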

Attention

We highly recommend using the default model or, if something different is needed, one of the Databricks Llama models. They offer the best token throughput and cost. Also, output_structured_schema is only available for Llama models: you can provide a Python dict with the desired schema of the LLM output and have it work out of the box.

Attention

You must pass a cost_component_name, otherwise this function will raise an exception.

Prompt Building#

Express clearly what the model’s goal is and how it should approach its task. Don’t be overly wordy: the prompt is wrapped around every row, so the token count grows fast. You reference your data (the columns of your DataFrame) with the <<col_name>> syntax. There’s an example below, and you can see more in the examples section.

Warning

Do not use single quotes (‘), because they break our prompt formatting.
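
For instance, a prompt that references two columns of the input data might look like this (the column names are illustrative):

# Illustrative prompt: each <<col_name>> placeholder is filled in with that row's value.
prompt = (
    "You are a support triage assistant. "
    "Given the ticket subject <<subject>> and body <<body>>, "
    "answer with a single word: low, medium, or high."
)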

Parameters:#

param data_source:

DataFrame or Delta table name to run the LLM batch over.

param prompt_source:

String of the prompt.

param output_column_name:

Name of the column to write the LLM output.

param output_table_name:

Optional Delta table to write results to.

param output_structured_schema:

Optional structured output dict (only available for llama models).

param model:

Name of the LLM model to use.

param max_output_tokens:

Maximum number of tokens the LLM can output.

param dry_run:

Whether to run the processing job locally or trigger a remote batch run.

param wait_for_completion:

Whether to wait for job completion (only applies in batch mode).

param table_operation:

Operation to perform on the output table.

param primary_key_columns:

Primary key columns for the output table.

param cost_component_name:

Name of the cost component.

returns:

Result of the processing job.

raises ValueError:

If output_table_name is not provided in batch mode or cost_component_name is missing.

raises MLOpsToolkitTooManyRowsForInteractiveUsage:

If dry_run=True with too many rows.

Examples#

Parsing a column to translate its content.#
from ml_toolkit.functions.llm import run_llm_batch

prompt = "You are an AI translator. Please translate the following text into english: <<text_col>>"
run_llm_batch(
    data_source="catalog.schema.input_table",
    prompt_source=prompt,
    output_table_name="catalog.schema.output_table",
    output_column_name="text_col_en",
    table_operation="overwrite",
    cost_component_name=...  # use your team's cost component here!
)
Using output_structured_schema#
import pyspark.sql.functions as F
from ml_toolkit.functions.llm import run_llm_batch

output_schema = {
    "name": "Error evaluation",
    "schema": {
        "type": "object",
        "properties": {
            "is_human_error": {"type": "boolean"},
            "confidence": {"type": "integer", "minimum": 0, "maximum": 10}
        }
    }
}

prompt = """
You are an expert python engineer. Your job is to look through error messages and output the source of the error and if the error looks like it came from a human error or not. Here is the record:
Error: <<error>>
"""

res = run_llm_batch(
    data_source=df,
    prompt_source=prompt,
    max_output_tokens=64,
    cost_component_name=...,  # use your team's cost component here!
    output_structured_schema=output_schema
)
df_llm = res["df"]
display(df_llm)

estimate_token_usage#

ml_toolkit.functions.llm.estimate_token_usage(data_source: str | pyspark.sql.dataframe.DataFrame, prompt_source: str = None, max_output_tokens: int = DEFAULT_MAX_OUTPUT_TOKENS, model: str = DEFAULT_MODEL_PROCESS)[source]#

Estimates the token usage of a given run of run_llm_batch (data + prompt). Quotas are applied at the net-token level, which is our internal metric to account for the fact that output tokens are 4x more expensive than input tokens.

Tip

Reduce max_output_tokens to greatly decrease cost.
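
As a rough illustration only (assuming, purely for this example, that net tokens weight output tokens 4x relative to input tokens; the exact internal formula may differ):

# Hypothetical arithmetic, assuming net = input + 4 * output per row; the real formula is internal.
rows = 100_000
input_tokens_per_row = 200
output_tokens_per_row = 256  # roughly bounded by max_output_tokens

net = rows * (input_tokens_per_row + 4 * output_tokens_per_row)  # 122,400,000 net tokens
net_smaller = rows * (input_tokens_per_row + 4 * 64)             # 45,600,000 with max_output_tokens=64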

Parameters:
  • data_source – DataFrame or Delta table name that would be passed to run_llm_batch.

  • prompt_source – String of the prompt.

  • max_output_tokens – Maximum number of tokens the LLM can output.

  • model – Name of the LLM model to use.

Returns total_input_tokens_estimate:

Estimate of total input token usage.

Returns total_output_tokens_estimate:

Estimate of total output token usage.

Returns total_net_tokens_estimate:

Estimate of total net token usage.

query#

ml_toolkit.functions.llm.query(prompt: str, inputs: str | dict | pyspark.sql.dataframe.DataFrame | None = None, context_files: str | List[str] | None = None, model: str = DEFAULT_MODEL_QUERY, response_format: Literal['text', 'json'] = 'text', tools: list[Callable | dict] | None = None, max_tool_calls: int = None, max_output_tokens: int | None = 2048, system_prompt: str | None = None, cost_component_name: str | None = None, model_extra_options: dict = {})[source]#

Query an LLM model and get its response.

This function lets you prompt an LLM and offers several features to streamline usage, such as input data (including DataFrames), context files, tools, and web search. Check the examples below for advanced usage of these features.

Parameters:
  • prompt – The main prompt text to send to the LLM model.

  • inputs – Optional input data that can be a string, dictionary, or DataFrame to provide context.

  • context_files – Optional file path(s) to include as context. Can be a single string or list of strings.

  • model – Optional name of the LLM model to use.

  • response_format – Format of the response - either “text” or “json”. Defaults to “text”.

  • tools – Optional list of functions or dicts that the LLM can use as tools.

  • max_tool_calls – Maximum number of tool calls allowed.

  • max_output_tokens – Maximum number of tokens in the response. Defaults to 2048.

  • system_prompt – Optional system prompt to set the behavior of the model.

  • cost_component_name – Name for tracking costs.

  • model_extra_options – Additional model-specific options as a dictionary.

Returns:

LLMResponse object containing the model’s response.

Raises:

ValueError – If there are errors reading context files.

Examples#

Querying a search-enabled model.#
from ml_toolkit.functions.llm import query

resp = query(
    "What are the latest info on the Trump Tariffs?",
    model="gpt-4o-mini-search-preview",
    model_extra_options={
        "web_search_options":{
            "user_location": {
                "type": "approximate",
                "approximate": {
                    "country": "US",
                }
            },
        }
    },
    cost_component_name=...  # use your team's cost component here!
)
print(resp.text)
Passing a dataframe as context and asking the model questions about our data.#
from ml_toolkit.functions.llm import query

df = spark.table(...).limit(10)
response = query(
    "How can I query this table to know the usage rates by hour over the last 21 days?",
    inputs=df,
    response_format="text",
    cost_component_name=...  # use your team's cost component here!
)
print(response.text)
Passing functions to increase the model’s capabilities.#
from ml_toolkit.functions.llm import query

def sum_two_numbers(a: float, b: float) -> float:
    """
    Sums two numbers and returns the result.
    """
    return a + b

response = query(
    "What's the sum between 1.1234 and 4.4321?",
    tools=[sum_two_numbers],
    response_format="json",
    cost_component_name=...  # use your team's cost component here!
)
print(response.json())
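Passing context files and a system prompt (a hypothetical sketch; the file path is a placeholder).#
from ml_toolkit.functions.llm import query

# Hypothetical sketch: ground the model with a local file and steer it with a system prompt.
response = query(
    "Summarize the key configuration options described in this document.",
    context_files="/dbfs/tmp/pipeline_config_notes.md",  # placeholder path
    system_prompt="You are a concise technical writer. Answer in bullet points.",
    max_output_tokens=512,
    cost_component_name=...,  # use your team's cost component here!
)
print(response.text)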

LLMResponse#

class ml_toolkit.functions.llm.query.LLMResponse[source]#

Class that defines the LLM response of the query function. It exposes the following attributes:

  • .text: returns the text response

  • .response: returns the raw LLM response object (openai.ChatCompletion)

And also the following methods:

  • .json(): tries to parse the output into a python dict if possible
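
For instance, a hypothetical usage sketch mirroring the query examples above:

from ml_toolkit.functions.llm import query

resp = query(
    "Return a JSON object with keys city and country for the Eiffel Tower.",
    response_format="json",
    cost_component_name=...,  # use your team's cost component here!
)

print(resp.json())    # dict, if the output parses as JSON
print(resp.text)      # raw text is always available
print(resp.response)  # underlying openai.ChatCompletion object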
