# Source code for ml_toolkit.functions.llm.helpers.token_usage

import json
from typing import Tuple

from pyspark.errors import IllegalArgumentException
from pyspark.sql.dataframe import DataFrame
import tiktoken
from yipit_databricks_client import get_spark_session
from yipit_databricks_client.dbutils import get_dbutils

from ml_toolkit.functions.llm.constants import (
    APPROX_COST_PER_MODEL_SIZE,
    DEFAULT_MAX_OUTPUT_TOKENS,
    DEFAULT_MODEL_PROCESS,
    GLOBAL_TOKEN_QUOTAS,
    MODELS_TO_SIZE_MAP,
    TOKEN_QUOTA_VOLUME_PATH,
)
from ml_toolkit.functions.llm.helpers.prompt import replace_prompt_placeholders
from ml_toolkit.ops.helpers.api import get_user_name
from ml_toolkit.ops.helpers.exceptions import MLOpsToolkitAboveTokenUsageQuota
from ml_toolkit.ops.helpers.logger import get_logger
from ml_toolkit.ops.storage.volumes import read_json_file_to_dict


def estimate_prompt_token_usage(prompt: str, model: str) -> int:
    """
    Estimate the number of input tokens a prompt consumes for a given model.

    ``tiktoken`` only works with OpenAI models, so the count is an
    approximation: the model is mapped to its size bucket through
    ``MODELS_TO_SIZE_MAP`` and the prompt is encoded with the OpenAI model
    used as a stand-in for that bucket.

    :param prompt: Prompt text to tokenize.
    :param model: Model name; must be a key of ``MODELS_TO_SIZE_MAP``.
    :returns: Number of tokens produced by encoding ``prompt``.
    :raises KeyError: If ``model`` is unknown or its size bucket has no
        stand-in tokenizer model.
    """
    # OpenAI model whose tokenizer approximates each size bucket.
    proxy_model_by_size = {
        "small_models": "gpt-4o-mini",
        "medium_models": "gpt-4o-mini",
        "big_models": "gpt-4o",
    }
    model_size = MODELS_TO_SIZE_MAP.get(model)
    if model_size not in proxy_model_by_size:
        # Fail with an actionable message instead of an opaque ``KeyError: None``.
        raise KeyError(
            f"Cannot estimate token usage for model {model!r}: "
            f"unsupported model size {model_size!r}."
        )

    encoding = tiktoken.encoding_for_model(proxy_model_by_size[model_size])
    return len(encoding.encode(prompt))


def estimate_cost_from_token_input_output(
    input_tokens: int, output_tokens: int, model: str
) -> float:
    """
    Estimate the cost for a given token usage and model.

    The model is resolved to its size bucket through ``MODELS_TO_SIZE_MAP``
    and the bucket's rates are read from ``APPROX_COST_PER_MODEL_SIZE``; the
    ``"input_tokens"`` and ``"output_tokens"`` rates are multiplied by the
    respective token counts and summed.

    :param input_tokens: Number of input (prompt) tokens.
    :param output_tokens: Number of output (completion) tokens.
    :param model: Model name; must be a key of ``MODELS_TO_SIZE_MAP``.
    :returns: Estimated cost of the run.
    :raises KeyError: If no cost table can be resolved for ``model``.
    """
    logger = get_logger()
    model_size = MODELS_TO_SIZE_MAP.get(model)
    model_costs = APPROX_COST_PER_MODEL_SIZE.get(model_size)
    if model_costs is None:
        # Without this guard an unknown model fails further down with
        # "'NoneType' object is not subscriptable".
        raise KeyError(
            f"Cannot estimate cost for model {model!r}: "
            f"no cost table for model size {model_size!r}."
        )
    logger.info(f"Model: {model} is size {model_size} with costs {model_costs}")

    return (
        model_costs["input_tokens"] * input_tokens
        + model_costs["output_tokens"] * output_tokens
    )


def estimate_net_token_usage(prompt: str, model: str, max_output_tokens: int) -> int:
    """
    Estimate the net token usage of a single prompt for a given model.

    Net tokens are the input tokens plus four times the output tokens. The
    output side is taken to be ``max_output_tokens``, i.e. the most the model
    is allowed to produce. The estimated cost is logged as a side effect.

    :param prompt: Prompt text to tokenize.
    :param model: Name of the model the prompt would be sent to.
    :param max_output_tokens: Maximum number of tokens the model may output.
    :returns: ``input_tokens + 4 * max_output_tokens``.
    """
    log = get_logger()
    output_weight = 4  # output tokens count four times in the net metric

    input_tokens = estimate_prompt_token_usage(prompt, model)
    output_tokens = max_output_tokens
    cost = estimate_cost_from_token_input_output(input_tokens, output_tokens, model)

    message = f"Estimated cost for {model} with {input_tokens} input tokens and {output_tokens} output tokens: {cost}$"
    log.info(message)

    net_tokens = input_tokens + output_weight * output_tokens
    return net_tokens


def estimate_net_batch_token_usage(
    prompt_source: str,
    df: DataFrame,
    max_output_tokens: int = 512,
    model: str = "gpt-4o-mini",
) -> tuple[int, int, int]:
    """
    Estimate input, output and net token usage for running a prompt over a DataFrame.

    Up to 50 rows are collected, the prompt is rendered for each of them with
    ``replace_prompt_placeholders`` and tokenized, and the resulting average
    is extrapolated to the full row count. Output usage assumes every row
    produces ``max_output_tokens`` tokens. The estimated cost is logged.

    Because the input estimate comes from an average, the input and net
    values may be non-integer floats.

    :param prompt_source: Prompt template whose placeholders are filled per row.
    :param df: DataFrame holding the rows the prompt would be run on.
    :param max_output_tokens: Maximum number of tokens the LLM can output per row.
    :param model: Name of the LLM model to use.
    :returns: Tuple of ``(total_input_tokens_estimate,
        total_output_tokens_estimate, total_net_tokens_estimate)`` where the
        net estimate is ``input + 4 * output``.
    """
    logger = get_logger()

    num_rows = df.count()
    rows_to_estimate = 50
    sampled_rows = df.limit(rows_to_estimate).collect()
    sampled_input_tokens = sum(
        estimate_prompt_token_usage(
            replace_prompt_placeholders(row.asDict(), prompt_source), model
        )
        for row in sampled_rows
    )
    # Average over the rows actually sampled: dividing by the fixed sample
    # size would underestimate any DataFrame with fewer rows than that.
    # An empty DataFrame yields an average of zero instead of dividing by 0.
    avg_tokens_per_input_row = (
        sampled_input_tokens / len(sampled_rows) if sampled_rows else 0.0
    )

    total_input_tokens_estimate = avg_tokens_per_input_row * num_rows
    total_output_tokens_estimate = max_output_tokens * num_rows
    total_net_tokens_estimate = (
        total_input_tokens_estimate + 4 * total_output_tokens_estimate
    )

    cost = estimate_cost_from_token_input_output(
        input_tokens=total_input_tokens_estimate,
        output_tokens=total_output_tokens_estimate,
        model=model,
    )
    logger.info(
        f"""
        Estimated cost for {model} running on {num_rows:,.0f} rows is: {cost:,.2f}$.
        - total_input_tokens: {total_input_tokens_estimate:,.0f}
        - max_output_tokens: {total_output_tokens_estimate:,.0f}
        - net_tokens (input+4*output): {total_net_tokens_estimate:,.0f}
        - estimated_cost: {cost:,.2f}$
        """
    )

    return (
        total_input_tokens_estimate,
        total_output_tokens_estimate,
        total_net_tokens_estimate,
    )


def get_user_quota(user: str, model: str) -> int:
    """
    Get user quota for LLM tokens by model size. Resolve order is:
    1. If user has a personal quota, use that
    2. If workspace has a default quota, use that
    3. If none of the above exist, use quota defined in code

    Personal quotas are read from the JSON file at ``TOKEN_QUOTA_VOLUME_PATH``
    (keyed by user, then by model size). The workspace default is read from
    the ``GLOBAL_LLM_QUOTAS`` secret in the ``WORKSPACE_CONFIGURATION`` scope,
    parsed as JSON keyed by model size. The code default comes from
    ``GLOBAL_TOKEN_QUOTAS``.

    Args:
        user: user name used as the key into the per-user quota file
        model: model name; must be a key of ``MODELS_TO_SIZE_MAP``

    Returns:
        The net-token quota (input + 4x output) that applies to the user.

    Raises:
        KeyError: if ``model`` is not in ``MODELS_TO_SIZE_MAP``, or no quota
            is found anywhere and the size is missing from
            ``GLOBAL_TOKEN_QUOTAS``.
    """
    logger = get_logger()
    model_size = MODELS_TO_SIZE_MAP[model]

    all_users_quota = read_json_file_to_dict(TOKEN_QUOTA_VOLUME_PATH)
    current_user_quota = all_users_quota.get(user, {})

    actual_quota = current_user_quota.get(model_size, None)
    if actual_quota is not None:
        logger.debug("Using user-specific quota.")
    else:
        try:
            dbutils = get_dbutils()
            raw_global_quotas = dbutils.secrets.get(
                "WORKSPACE_CONFIGURATION", "GLOBAL_LLM_QUOTAS"
            )
        except IllegalArgumentException:
            # Best-effort lookup: presumably raised when the secret scope/key
            # is not configured in this workspace. Record the fallback rather
            # than swallowing it silently.
            logger.debug(
                "Workspace quota secret could not be read; skipping workspace quota."
            )
        else:
            actual_quota = json.loads(raw_global_quotas).get(model_size, None)

        if actual_quota is not None:
            logger.debug("Using workspace-specific quota.")
        else:
            actual_quota = GLOBAL_TOKEN_QUOTAS[model_size]
            logger.debug("Using code-default quota.")

    logger.info(
        f"Your net-token quota is {actual_quota:,.0f} tokens (input + 4x output)."
    )

    return actual_quota


def raise_exception_if_above_token_usage_quota(net_token_count: int, model: str):
    """
    Raise if an estimated net-token count exceeds the current user's quota.

    The current user is resolved with ``get_user_name`` and the quota with
    ``get_user_quota``.

    :param net_token_count: Estimated net tokens (input + 4x output) of the run.
    :param model: Name of the LLM model the run would use.
    :raises MLOpsToolkitAboveTokenUsageQuota: If ``net_token_count`` is
        strictly greater than the user's quota.
    """
    user_name = get_user_name()
    user_quota = get_user_quota(user_name, model)

    # Log the comparison before deciding, so the values are traceable for
    # runs that pass the check as well as for those that fail it.
    logger = get_logger()
    logger.debug(
        f"Checking quota usage:\n"
        f"estimated_net_tokens={net_token_count:,.0f}, net_tokens_quota={user_quota:,.0f}"
    )

    if net_token_count > user_quota:
        raise MLOpsToolkitAboveTokenUsageQuota(
            estimated_net_tokens=net_token_count,
            net_tokens_quota=user_quota,
        )


def estimate_token_usage(
    data_source: str | DataFrame,
    prompt_source: str | None = None,
    max_output_tokens: int = DEFAULT_MAX_OUTPUT_TOKENS,
    model: str = DEFAULT_MODEL_PROCESS,
) -> Tuple[int, int, int]:
    """
    Estimates the token usage of a given run of ``run_llm_batch`` (data + prompt).
    Quotas are applied at the net-token level, which is our internal metric to
    account for the fact that output tokens are 4x more expensive than input
    tokens.

    .. tip:: Reduce the number of ``max_output_tokens`` in order to greatly decrease cost.

    :param data_source: DataFrame or Delta table name to run_llm_batch.
    :param prompt_source: String of the prompt.
    :param max_output_tokens: Maximum number of tokens the LLM can output.
    :param model: Name of the LLM model to use.
    :returns total_input_tokens_estimate: Estimate of total input token usage.
    :returns total_output_tokens_estimate: Estimate of total output token usage.
    :returns total_net_tokens_estimate: Estimate of total net token usage.
    """
    # A string is treated as a table name and resolved to a DataFrame.
    if isinstance(data_source, str):
        spark = get_spark_session()
        data_source = spark.table(data_source)

    # The batch estimator already returns the (input, output, net) 3-tuple.
    return estimate_net_batch_token_usage(
        prompt_source=prompt_source,
        df=data_source,
        model=model,
        max_output_tokens=max_output_tokens,
    )