import json
from typing import Tuple
from pyspark.errors import IllegalArgumentException
from pyspark.sql.dataframe import DataFrame
import tiktoken
from yipit_databricks_client import get_spark_session
from yipit_databricks_client.dbutils import get_dbutils
from ml_toolkit.functions.llm.constants import (
APPROX_COST_PER_MODEL_SIZE,
DEFAULT_MAX_OUTPUT_TOKENS,
DEFAULT_MODEL_PROCESS,
GLOBAL_TOKEN_QUOTAS,
MODELS_TO_SIZE_MAP,
TOKEN_QUOTA_VOLUME_PATH,
)
from ml_toolkit.functions.llm.helpers.prompt import replace_prompt_placeholders
from ml_toolkit.ops.helpers.api import get_user_name
from ml_toolkit.ops.helpers.exceptions import MLOpsToolkitAboveTokenUsageQuota
from ml_toolkit.ops.helpers.logger import get_logger
from ml_toolkit.ops.storage.volumes import read_json_file_to_dict
def estimate_prompt_token_usage(prompt: str, model: str) -> int:
"""
Estimate the token usage for a given prompt and model.
"""
# tiktoken only works with openAI models
model_size = MODELS_TO_SIZE_MAP.get(model)
    # Pick a representative OpenAI model per size bucket for tiktoken's encoding lookup
    model_size_to_encoding_model = {
        "small_models": "gpt-4o-mini",
        "medium_models": "gpt-4o-mini",
        "big_models": "gpt-4o",
    }
    model_to_estimate = model_size_to_encoding_model[model_size]
    encoding = tiktoken.encoding_for_model(model_to_estimate)
return len(encoding.encode(prompt))
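
# Usage sketch (illustrative only; assumes the model name is a key of MODELS_TO_SIZE_MAP):
#
#     >>> n = estimate_prompt_token_usage("Classify the sentiment of: great product!", "gpt-4o-mini")
#     >>> isinstance(n, int)
#     True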
def estimate_cost_from_token_input_output(
input_tokens: int, output_tokens: int, model: str
) -> float:
"""
Estimate the cost for a given token usage and model.
"""
logger = get_logger()
model_size = MODELS_TO_SIZE_MAP.get(model)
model_costs = APPROX_COST_PER_MODEL_SIZE.get(model_size)
logger.info(f"Model: {model} is size {model_size} with costs {model_costs}")
return (
model_costs["input_tokens"] * input_tokens
+ model_costs["output_tokens"] * output_tokens
)
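
# The cost is a linear combination of per-token prices. Assuming APPROX_COST_PER_MODEL_SIZE
# stores a per-token price under "input_tokens" and "output_tokens" for each size bucket,
# the arithmetic is (prices below are hypothetical):
#
#     cost = input_price * input_tokens + output_price * output_tokens
#
#     e.g. 10,000 input tokens and 2,000 output tokens at $1e-7 per input token and
#     $4e-7 per output token:
#     cost = 1e-7 * 10_000 + 4e-7 * 2_000 = 0.001 + 0.0008 = $0.0018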
def estimate_net_token_usage(prompt: str, model: str, max_output_tokens: int) -> int:
"""
Estimate the net token usage for a given prompt and model.
"""
logger = get_logger()
input_tokens = estimate_prompt_token_usage(prompt, model)
output_tokens = max_output_tokens
cost = estimate_cost_from_token_input_output(input_tokens, output_tokens, model)
logger.info(
f"Estimated cost for {model} with {input_tokens} input tokens and {output_tokens} output tokens: {cost}$"
)
return input_tokens + 4 * output_tokens
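
# "Net tokens" weight output tokens 4x, mirroring the ~4x price difference:
#
#     net_tokens = input_tokens + 4 * output_tokens
#
#     e.g. a 300-token prompt with max_output_tokens=512 (values are illustrative):
#     300 + 4 * 512 = 2,348 net tokens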
def estimate_net_batch_token_usage(
prompt_source: str,
df: DataFrame,
max_output_tokens: int = 512,
model: str = "gpt-4o-mini",
) -> tuple[int, int, int]:
    """
    Estimate total input, output, and net token usage for a batch run over ``df``.

    Input tokens are estimated by rendering the prompt for a sample of up to 50 rows
    and extrapolating the average to the full row count.
    """
    logger = get_logger()
    num_rows = df.count()
    rows_to_estimate = 50
    avg_tokens_per_input_row = 0
    sampled_rows = df.limit(rows_to_estimate).collect()
    for row in sampled_rows:
        prompt = replace_prompt_placeholders(row.asDict(), prompt_source)
        input_tokens = estimate_prompt_token_usage(prompt, model)
        avg_tokens_per_input_row += input_tokens
    # Divide by the number of rows actually sampled so small DataFrames are not underestimated
    avg_tokens_per_input_row /= max(len(sampled_rows), 1)
total_input_tokens_estimate = avg_tokens_per_input_row * num_rows
total_output_tokens_estimate = max_output_tokens * num_rows
total_net_tokens_estimate = (
total_input_tokens_estimate + 4 * total_output_tokens_estimate
)
cost = estimate_cost_from_token_input_output(
input_tokens=total_input_tokens_estimate,
output_tokens=total_output_tokens_estimate,
model=model,
)
logger.info(
f"""
Estimated cost for {model} running on {num_rows:,.0f} rows is: {cost:,.2f}$.
- total_input_tokens: {total_input_tokens_estimate:,.0f}
- max_output_tokens: {total_output_tokens_estimate:,.0f}
- net_tokens (input+4*output): {total_net_tokens_estimate:,.0f}
- estimated_cost: {cost:,.2f}$
"""
)
return (
total_input_tokens_estimate,
total_output_tokens_estimate,
total_net_tokens_estimate,
)
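
# Minimal usage sketch (illustrative; assumes an active Spark session and that the prompt's
# placeholder syntax, here assumed to be "{review}", matches what replace_prompt_placeholders
# expects):
#
#     >>> spark = get_spark_session()
#     >>> df = spark.createDataFrame([("great product",), ("arrived broken",)], ["review"])
#     >>> estimate_net_batch_token_usage(
#     ...     prompt_source="Classify the sentiment of: {review}",
#     ...     df=df,
#     ...     max_output_tokens=64,
#     ...     model="gpt-4o-mini",
#     ... )  # doctest: +SKIP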
def get_user_quota(user: str, model: str) -> int:
"""
    Get the user's quota for LLM tokens by model size. The resolution order is:
    1. If the user has a personal quota, use that
    2. If the workspace has a default quota, use that
    3. Otherwise, use the quota defined in code
    Args:
        user: user name used to look up a personal quota
        model: model name
"""
logger = get_logger()
model_size = MODELS_TO_SIZE_MAP[model]
all_users_quota = read_json_file_to_dict(TOKEN_QUOTA_VOLUME_PATH)
current_user_quota = all_users_quota.get(user, {})
actual_quota = current_user_quota.get(model_size, None)
if actual_quota is not None:
logger.debug("Using user-specific quota.")
else:
try:
dbutils = get_dbutils()
global_quotas = json.loads(
dbutils.secrets.get("WORKSPACE_CONFIGURATION", "GLOBAL_LLM_QUOTAS")
)
actual_quota = global_quotas.get(model_size, None)
except IllegalArgumentException:
pass
if actual_quota is not None:
logger.debug("Using workspace-specific quota.")
else:
actual_quota = GLOBAL_TOKEN_QUOTAS[model_size]
logger.debug("Using code-default quota.")
logger.info(
f"Your net-token quota is {actual_quota:,.0f} tokens (input + 4x output)."
)
return actual_quota
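
# Sketch of the quota sources, in resolution order (the JSON shape below is inferred from
# the lookups above and is illustrative):
#
#     1. Per-user file at TOKEN_QUOTA_VOLUME_PATH, e.g.
#        {"jane.doe@example.com": {"small_models": 5000000, "big_models": 500000}}
#     2. Workspace secret "WORKSPACE_CONFIGURATION" / "GLOBAL_LLM_QUOTAS", with the same
#        {"<model_size>": <quota>} shape
#     3. Code default GLOBAL_TOKEN_QUOTAS[model_size]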
def raise_exception_if_above_token_usage_quota(net_token_count: int, model: str):
    """
    Raise ``MLOpsToolkitAboveTokenUsageQuota`` if the estimated net token count exceeds the user's quota.
    """
    user_name = get_user_name()
user_quota = get_user_quota(user_name, model)
    logger = get_logger()
    logger.debug(
        f"Checking quota usage:\n"
        f"estimated_net_tokens={net_token_count:,.0f}, net_tokens_quota={user_quota:,.0f}"
    )
    if net_token_count > user_quota:
raise MLOpsToolkitAboveTokenUsageQuota(
estimated_net_tokens=net_token_count,
net_tokens_quota=user_quota,
)
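
# Typical guard before launching a batch run (illustrative only):
#
#     >>> _, _, net_tokens = estimate_net_batch_token_usage(prompt, df)  # doctest: +SKIP
#     >>> raise_exception_if_above_token_usage_quota(net_tokens, "gpt-4o-mini")  # doctest: +SKIP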
def estimate_token_usage(
data_source: str | DataFrame,
    prompt_source: str | None = None,
max_output_tokens: int = DEFAULT_MAX_OUTPUT_TOKENS,
model: str = DEFAULT_MODEL_PROCESS,
) -> Tuple[int, int, int]:
"""
    Estimates the token usage of a given run of ``run_llm_batch`` (data + prompt). Quotas are applied at the
    net-token level, our internal metric that accounts for output tokens being 4x more expensive than input
    tokens.

    .. tip:: Reduce ``max_output_tokens`` to greatly decrease cost.

    :param data_source: DataFrame or Delta table name to run ``run_llm_batch`` on.
    :param prompt_source: String of the prompt.
    :param max_output_tokens: Maximum number of tokens the LLM can output.
    :param model: Name of the LLM model to use.
    :returns total_input_tokens_estimate: Estimate of total input token usage.
    :returns total_output_tokens_estimate: Estimate of total output token usage.
    :returns total_net_tokens_estimate: Estimate of total net token usage.
"""
if isinstance(data_source, str):
spark = get_spark_session()
data_source = spark.table(data_source)
(
total_input_tokens_estimate,
total_output_tokens_estimate,
total_net_tokens_estimate,
) = estimate_net_batch_token_usage(
prompt_source=prompt_source,
df=data_source,
model=model,
max_output_tokens=max_output_tokens,
)
return (
total_input_tokens_estimate,
total_output_tokens_estimate,
total_net_tokens_estimate,
)
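
# End-to-end sketch: estimating usage for a Delta table before calling ``run_llm_batch``
# (the table name and prompt are hypothetical; the "{review_text}" placeholder is assumed
# to match a column in the table):
#
#     >>> estimate_token_usage(
#     ...     data_source="analytics.reviews_sample",
#     ...     prompt_source="Summarize this review: {review_text}",
#     ...     max_output_tokens=128,
#     ... )  # doctest: +SKIP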