# Source code for ml_toolkit.functions.llm.run_llm_batch.interface

from typing import List, Literal, Optional, Union
import uuid

from pyspark.sql import DataFrame
from yipit_databricks_client import get_spark_session
from yipit_databricks_client.helpers.telemetry import track_usage

from ml_toolkit.functions.llm.constants import (
    DEFAULT_MAX_OUTPUT_TOKENS,
    DEFAULT_MODEL_PROCESS,
    MAX_NUM_ROWS_INTERACTIVE_MODE,
)
from ml_toolkit.functions.llm.helpers.logging import log_model_usage
from ml_toolkit.functions.llm.helpers.prompt import find_cols_in_prompt_source
from ml_toolkit.functions.llm.helpers.token_usage import (
    estimate_net_batch_token_usage,
    raise_exception_if_above_token_usage_quota,
)
from ml_toolkit.functions.llm.run_llm_batch.function import (
    process_batch,
    process_interactive,
)
from ml_toolkit.ops.helpers.exceptions import MLOpsToolkitTooManyRowsForInteractiveUsage
from ml_toolkit.ops.helpers.logger import get_logger
from ml_toolkit.ops.helpers.validation import (
    assert_columns_exists_in_dataframe,
    assert_columns_exists_in_table,
    assert_table_exists,
    for_loop_guardrail,
)


def resolve_dry_run(data_source: DataFrame, dry_run: Optional[bool] = None) -> bool:
    """Decide whether a job runs in interactive (dry-run) mode or as a remote batch.

    The rows of ``data_source`` are counted once and compared against
    ``MAX_NUM_ROWS_INTERACTIVE_MODE``.

    :param data_source: Spark DataFrame that is going to be processed.
    :param dry_run: Mode explicitly requested by the caller. ``None`` means
        "choose automatically": interactive when the row count is at most
        ``MAX_NUM_ROWS_INTERACTIVE_MODE``, batch otherwise. Any other value is
        interpreted by its truthiness.
    :returns: ``True`` for interactive (dry-run) mode, ``False`` for batch mode.
        Always a real ``bool``.
    :raises MLOpsToolkitTooManyRowsForInteractiveUsage: If a dry run is
        requested but the data exceeds the interactive row limit.
    """
    num_rows = data_source.count()
    within_limit = num_rows <= MAX_NUM_ROWS_INTERACTIVE_MODE

    if dry_run is None:
        # No explicit choice: small inputs default to interactive mode.
        return within_limit

    # Normalise truthy/falsy values (e.g. ``numpy.True_`` or ``1``). Without
    # this they would skip the row-limit guard below, since ``x is True`` is
    # False for them, and confuse callers that branch on the result.
    resolved = bool(dry_run)
    if resolved and not within_limit:
        raise MLOpsToolkitTooManyRowsForInteractiveUsage(
            num_rows=num_rows, max_num_rows=MAX_NUM_ROWS_INTERACTIVE_MODE
        )
    return resolved


@track_usage
@for_loop_guardrail(min_interval_seconds=10)
def run_llm_batch(
    data_source: Union[DataFrame, str],
    prompt_source: Optional[str] = None,
    output_column_name: str = "llm_output",
    output_table_name: Optional[str] = None,
    output_structured_schema: Optional[dict] = None,
    model: str = DEFAULT_MODEL_PROCESS,
    dry_run: Optional[bool] = None,
    max_output_tokens: int = DEFAULT_MAX_OUTPUT_TOKENS,
    wait_for_completion: bool = True,
    table_operation: Literal["overwrite", "append"] = "overwrite",
    primary_key_columns: Optional[List[str]] = None,
    cost_component_name: Optional[str] = None,
) -> dict:
    """
    Performs a row-level operation on the input data (spark ``DataFrame`` or table) by passing
    the ``prompt_source`` (along with the row-data mentioned there) to an LLM model and writes
    the output to ``output_column_name``.

    There are two operation modes:

    1. ``dry_run=True``: good for quick experimentation, POCs and prompt engineering; only works
       with up to ``MAX_NUM_ROWS_INTERACTIVE_MODE`` rows of data (~1k).
    2. ``dry_run=False``: should be used for the production pipeline or when running over a lot
       of data. Writes data to ``output_table_name`` and can either ``overwrite`` or ``append``
       (set via ``table_operation``).

    When ``dry_run`` is left as ``None``, the mode is chosen automatically from the number of
    rows in ``data_source``.

    .. caution::
        ``dry_run=True`` does incur usage and costs. It's a mode designed to allow faster and
        cheaper experimentation. If all you want is to estimate usage on a bigger dataset,
        please run ``estimate_token_usage``.

    .. attention::
        We highly recommend using the default model or, if something different is needed, to go
        for the databricks llama models. They offer the best token throughput and cost.

        Also, ``output_structured_schema`` is only available for llama models. This means you
        can provide a python dict with the desired schema of the LLM output and have that work
        out of the box.

    .. attention::
        You **must** pass a ``cost_component_name``, otherwise this function will raise an
        exception.

    Prompt Building
    ^^^^^^^^^^^^^^^

    Express clearly what the model's goal is and how it should approach its task.
    Don't be overly wordy: the prompt is wrapped around every row, so the token count grows fast.

    You reference your data (the columns of your dataframe) with the ``<<col_name>>`` syntax.
    There's an example below, and more in the examples section.

    .. code-block:: python

        prompt = "Test prompt <<col_1>> and <<col_2>>."

    .. warning::
        Do **not** use simple quotes ('), because they break our prompt formatting.

    Parameters:
    ^^^^^^^^^^^

    :param data_source: DataFrame or Delta table name to run_llm_batch.
    :param prompt_source: String of the prompt.
    :param output_column_name: Name of the column to write the LLM output.
    :param output_table_name: Optional Delta table to write results to (required in batch mode).
    :param output_structured_schema: Optional structured output dict (only available for llama
        models).
    :param model: Name of the LLM model to use.
    :param max_output_tokens: Maximum number of tokens the LLM can output.
    :param dry_run: Whether to run the processing job locally (``True``) or trigger a remote
        batch run (``False``). ``None`` chooses automatically based on the row count.
    :param wait_for_completion: Whether to wait for job completion (only applies in batch mode).
        Note: this function does not currently forward it to the batch job.
    :param table_operation: Operation to perform on the output table.
    :param primary_key_columns: Primary key columns for the output table.
        Note: this function does not currently forward it to the batch job.
    :param cost_component_name: Name of the cost component (mandatory).
    :returns: Dict with keys ``source_uuid``, ``model``, ``input_tokens``, ``output_tokens``
        (both estimates), ``cost_component_name``, ``prompt_source``, ``output_table_name``,
        ``output_column_name``, ``df`` (result of the processing call) and ``error_message``.
    :raises ValueError: If ``cost_component_name`` is missing, or, in batch mode, if
        ``output_table_name`` or ``output_column_name`` is missing.
    :raises MLOpsToolkitTooManyRowsForInteractiveUsage: If dry_run=True with too many rows.

    Examples
    ^^^^^^^^^^

    .. code-block:: python
        :caption: Parsing a column to translate its content.

        from ml_toolkit.functions.llm import run_llm_batch

        prompt = "You are an AI translator. Please translate the following text into english: <<text_col>>"

        run_llm_batch(
            data_source="catalog.schema.input_table",
            prompt_source=prompt,
            output_table_name="catalog.schema.output_table",
            output_column_name="text_col_en",
            table_operation="overwrite",
            cost_component_name=...,  # use your team's cost component here!
        )

    .. code-block:: python
        :caption: Using ``output_structured_schema``

        import pyspark.sql.functions as F
        from ml_toolkit.functions.llm import run_llm_batch

        output_schema = {
            "name": "Error evaluation",
            "schema": {
                "type": "object",
                "properties": {
                    "is_human_error": {"type": "boolean"},
                    "confidence": {"type": "integer", "minimum": 0, "maximum": 10}
                }
            }
        }

        prompt = \"""
        You are an expert python engineer. Your job is to look through error messages and
        output the source of the error and if the error looks like it came from a human
        error or not.

        Here is the record:
        Error: <<error>>
        \"""

        res = run_llm_batch(
            data_source=df,
            prompt_source=prompt,
            max_output_tokens=64,
            cost_component_name=...,  # use your team's cost component here!
            output_structured_schema=output_schema
        )
        df_llm = res["df"]
        display(df_llm)
    """
    logger = get_logger()

    # Argument that needs no data access: check it before any Spark work runs.
    if not cost_component_name:
        raise ValueError("`cost_component_name` is obligatory.")

    # Generate a unique ID for this processing job
    source_uuid = str(uuid.uuid4())

    data_source = _load_data_source(data_source, prompt_source)
    dry_run = resolve_dry_run(data_source, dry_run=dry_run)

    if dry_run:
        logger.info(f"Running job {source_uuid} with dry_run={dry_run}.")
    else:
        # Batch mode needs a destination; fail before the (costly) token estimation.
        _validate_batch_params(output_table_name, output_column_name)
        logger.info(
            f"Running job {source_uuid} with dry_run={dry_run} and table_operation={table_operation}."
        )

    estimated_input_tokens, estimated_output_tokens, net_tokens = (
        estimate_net_batch_token_usage(
            prompt_source=prompt_source,
            df=data_source,
            model=model,
            max_output_tokens=max_output_tokens,
        )
    )
    raise_exception_if_above_token_usage_quota(net_tokens, model)

    # Count the submitted rows before processing. The usage log then reflects
    # the input that was sent, not a re-evaluation of a lazy source after
    # the job has run.
    num_rows = data_source.count()

    # NOTE(review): `wait_for_completion` and `primary_key_columns` are accepted
    # and documented but never forwarded to `process_batch`. Confirm whether
    # `process_batch` supports them and wire them through.
    if dry_run:
        df_llm = process_interactive(
            df=data_source,
            model=model,
            output_column_name=output_column_name,
            prompt_source=prompt_source,
            max_output_tokens=max_output_tokens,
            output_structured_schema=output_structured_schema,
        )
    else:
        df_llm = process_batch(
            df=data_source,
            model=model,
            output_table_name=output_table_name,
            output_column_name=output_column_name,
            output_structured_schema=output_structured_schema,
            table_operation=table_operation,
            prompt_source=prompt_source,
            max_output_tokens=max_output_tokens,
        )

    result = {
        "source_uuid": source_uuid,
        "model": model,
        "input_tokens": estimated_input_tokens,
        "output_tokens": estimated_output_tokens,
        "cost_component_name": cost_component_name,
        "prompt_source": prompt_source,
        "output_table_name": output_table_name,
        "output_column_name": output_column_name,
        "df": df_llm,
        "error_message": None,
    }

    # NOTE(review): the "actual" token counts are the estimates; no measured
    # usage is available at this point.
    log_model_usage(
        source_uuid,
        model,
        estimated_input_tokens=estimated_input_tokens,
        estimated_output_tokens=estimated_output_tokens,
        actual_input_tokens=estimated_input_tokens,
        actual_output_tokens=estimated_output_tokens,
        cost_component_name=cost_component_name,
        prompt=prompt_source,
        num_rows=num_rows,
    )

    return result


def _load_data_source(
    data_source: Union[DataFrame, str], prompt_source: Optional[str]
) -> DataFrame:
    """Resolve ``data_source`` to a DataFrame after checking it has every column the prompt references.

    :param data_source: A Spark DataFrame, or the name of a table to load.
    :param prompt_source: Prompt whose ``<<col>>`` references are checked.
    :returns: The DataFrame to process.
    """
    cols = find_cols_in_prompt_source(prompt_source)
    if isinstance(data_source, str):
        spark = get_spark_session()
        assert_table_exists(data_source)
        assert_columns_exists_in_table(data_source, columns=cols)
        return spark.table(data_source)

    assert_columns_exists_in_dataframe(data_source, columns=cols, table_name="dataframe")
    return data_source


def _validate_batch_params(
    output_table_name: Optional[str], output_column_name: Optional[str]
) -> None:
    """Raise ``ValueError`` when an argument required by batch mode is missing or empty."""
    if not output_table_name:
        raise ValueError("`output_table_name` is required for batch processing mode.")
    if not output_column_name:
        raise ValueError("`output_column_name` is required for batch processing mode.")