"""Main evaluation functions for executing model evaluations.
This module provides:
- run_evaluation: Endpoint-based inference + metrics + MLflow logging
- run_local_evaluation: Local GPU inference + metrics + MLflow logging
- run_experiment: Multi-model comparison (supports both endpoint and local models)
- run_local_variance_analysis: Sampling-based variance analysis for local models
"""
from __future__ import annotations
from datetime import datetime
from typing import Any, Callable
import mlflow
from pyspark.sql import functions as F
import yaml
import yipit_databricks_client as ydbc
from yipit_databricks_utils.helpers.delta import create_table
from ml_toolkit.functions.eval_utils.constants import (
BUILTIN_METRICS,
DEFAULT_BATCH_SIZE,
DEFAULT_MAX_CONCURRENT,
DEFAULT_MAX_OUTPUT_TOKENS,
DEFAULT_TIMEOUT_SECONDS,
MAX_ERROR_RATE_THRESHOLD,
ExperimentStatus,
Metric,
RunStatus,
)
from ml_toolkit.functions.eval_utils.helpers.config import (
DatasetConfig,
Experiment,
LocalModelConfig,
ModelConfig,
)
from ml_toolkit.functions.eval_utils.helpers.eval import (
_compute_llm_judge_metrics,
_compute_row_metrics,
_create_combined_results_table,
_generate_experiment_id,
_generate_parent_run_id,
_generate_run_id,
_get_mlflow_experiment_path,
_register_eval_run,
_register_experiment,
)
from ml_toolkit.functions.eval_utils.helpers.inference import run_inference
from ml_toolkit.functions.eval_utils.helpers.metrics import aggregate_all_metrics
from ml_toolkit.functions.eval_utils.helpers.types import (
EvalRunResult,
ExperimentResult,
LLMJudgeConfig,
RemoteExperimentRun,
_EvalRunSnapshot,
)
from ml_toolkit.ops.helpers.exceptions import (
AggregateExceptionHandler,
MLOpsToolkitTableNotFoundException,
)
from ml_toolkit.ops.helpers.logger import get_logger, suppress_table_creation_logs
from ml_toolkit.ops.helpers.validation import table_exists
from ml_toolkit.ops.storage.workspace import create_workspace_directory
# [docs]  (Sphinx "view source" link artifact — not executable Python; kept as a comment)
def run_evaluation(
    eval_table: str | None = None,
    *,
    primary_key: str | None = None,
    model_name: str | None = None,
    endpoint: str | None = None,
    litellm_model: str | None = None,
    version: int | None = None,
    prompt_template: str | None = None,
    system_prompt: str | None = None,
    prompt_registry_name: str | None = None,
    prompt_version: str | int | None = None,
    prompt_alias: str | None = None,
    metrics: list[str | Callable] | None = None,
    llm_judge_config: LLMJudgeConfig | None = None,
    mlflow_experiment: str | None = None,
    run_name: str | None = None,
    batch_size: int = DEFAULT_BATCH_SIZE,
    max_concurrent: int = DEFAULT_MAX_CONCURRENT,
    timeout_seconds: float = DEFAULT_TIMEOUT_SECONDS,
    input_column: str | None = "input",
    expected_output_column: str | None = "expected_output",
    additional_context_columns: list[str] | None = None,
    tags: dict[str, str] | None = None,
    max_output_tokens: int | None = DEFAULT_MAX_OUTPUT_TOKENS,
    litellm_model_kwargs: dict[str, Any] | None = None,
    # Object-based configuration (values override defaults, but explicit params take precedence)
    dataset: DatasetConfig | None = None,
    model_config: ModelConfig | None = None,
    is_nested: bool = False,
    # Internal parameters for experiment tracking (not part of public API)
    _experiment_id: str | None = None,
    _model_name: str | None = None,
) -> EvalRunResult:
    """Execute an evaluation run against a model endpoint.

    Runs inference on all rows in the eval table using ai_query or OpenAI-compatible
    clients, computes specified metrics, and logs results to MLflow.

    Parameters can be provided directly or via DatasetConfig/ModelConfig objects.
    When both are provided, explicit parameters take precedence over object values.

    Args:
        eval_table: Fully qualified eval table name (catalog.schema.table).
            Can be provided via dataset object instead.
        primary_key: Name of the primary key column. Can be provided via dataset object.
        model_name: Unity Catalog model name (e.g., 'catalog.schema.model_name').
            If provided without endpoint, the endpoint name will be automatically
            inferred from the model name. Either model_name/endpoint or litellm_model
            must be specified. Can be provided via config object.
        endpoint: Databricks Model Serving endpoint name. Optional if model_name
            is provided (will be inferred). Either model_name/endpoint or litellm_model
            must be specified (not both). Can be provided via config object.
        litellm_model: LiteLLM model identifier (e.g., 'gpt-4', 'claude-3').
            Either model_name/endpoint or litellm_model must be specified (not both).
            Can be provided via config object.
        version: Optional integer version for routing to a specific
            model version in multi-version serving endpoints. Only used when
            endpoint is specified. Example: 1, 2, 3. Defaults to None (latest version).
            Can be provided via config object.
        prompt_template: Template for constructing the user message using <<variable>>
            syntax. Available variables: <<input>>, <<candidates>>, and any columns from
            additional_context_columns. Defaults to "<<input>>". Can be provided via config object.
        system_prompt: Optional system prompt for chat completion. Can be provided via config object.
        prompt_registry_name: Fully qualified prompt name from MLflow Prompt Registry
            (e.g., 'catalog.schema.prompt_name') or URI (e.g., 'prompts:/catalog.schema.prompt_name@alias').
            If provided, loads prompt_template from registry (converts {{variable}} to <<variable>>).
            Can be provided via config object.
        prompt_version: Specific version to load from registry. Ignored if prompt_registry_name
            includes version/alias or if prompt_alias is provided. Can be provided via config object.
        prompt_alias: Alias to load from registry (e.g., 'production'). Takes precedence over prompt_version.
            Can be provided via config object.
        metrics: List of metrics to compute. Defaults to ['latency', 'token_count'].
            Can include built-in metric names ('latency', 'token_count', 'exact_match',
            'fuzzy_match', 'llm_judge') or custom scorer callables
            (``Callable[[str, str, dict], float]``). Can be provided via config object.
        llm_judge_config: Configuration for LLM-as-judge metric. Required
            if 'llm_judge' is in metrics list. Can be provided via config object.
        mlflow_experiment: MLflow experiment name. If not specified, defaults
            to '/Evals/{schema_name}/{table_base_name}' derived from eval_table.
        run_name: Optional name for the MLflow run. Defaults to
            '{endpoint_or_model}_{timestamp}'. Can be provided via config.name.
        batch_size: Number of rows to process per batch.
        max_concurrent: Maximum concurrent requests to the endpoint.
        timeout_seconds: Timeout per inference request.
        input_column: Name of the input column. Defaults to 'input'. Can be provided via dataset object.
        expected_output_column: Name of the expected output column for
            comparison metrics. Set to None if not available. Can be provided via dataset object.
        additional_context_columns: Additional columns to include in prompt
            template context. Can be provided via dataset object.
        tags: Optional tags to apply to the MLflow run. Merged with config.tags if both provided.
        max_output_tokens: Maximum output tokens for model response. Can be provided via config object.
        litellm_model_kwargs: Additional kwargs for LiteLLM model.
        dataset: DatasetConfig object containing table reference and column configuration.
            Values are used as defaults when explicit parameters are not provided.
        model_config: ModelConfig object containing model and prompt configuration.
            Values are used as defaults when explicit parameters are not provided.

    Returns:
        EvalRunResult containing:
            - run_id: Unique identifier for this eval run
            - mlflow_run_id: MLflow run ID
            - results_table: Fully qualified name of results table
            - metrics_summary: Dict of aggregated metrics
            - row_count: Number of rows evaluated
            - error_count: Number of failed inferences

    Raises:
        ValueError: If neither endpoint nor litellm_model specified.
        ValueError: If both endpoint and litellm_model specified.
        ValueError: If 'llm_judge' in metrics but llm_judge_config is None.
        MLOpsToolkitTableNotFoundException: If eval_table doesn't exist.
        MLOpsToolkitEvalRunFailedException: If error rate exceeds threshold.

    Example:
        >>> # Using model_name (endpoint automatically inferred):
        >>> result = run_evaluation(
        ...     eval_table="catalog.schema.vendor_eval_20260126",
        ...     primary_key="row_id",
        ...     model_name="catalog.schema.vendor_tagger_model",
        ...     prompt_template="Tag the following vendor name: <<input>>",
        ...     metrics=["latency", "token_count", "exact_match"],
        ... )
        >>> # Using explicit endpoint:
        >>> result = run_evaluation(
        ...     eval_table="catalog.schema.vendor_eval_20260126",
        ...     primary_key="row_id",
        ...     endpoint="vendor-tagger-v1",
        ...     prompt_template="Tag the following vendor name: <<input>>",
        ...     metrics=["latency", "token_count", "exact_match"],
        ... )
        >>> # Using version-based routing (defaults to latest version):
        >>> result = run_evaluation(
        ...     eval_table="catalog.schema.vendor_eval",
        ...     model_name="catalog.schema.my_model",
        ...     prompt_template="Extract entity: <<input>>",
        ... )
        >>> # Using specific version:
        >>> result = run_evaluation(
        ...     eval_table="catalog.schema.vendor_eval",
        ...     model_name="catalog.schema.my_model",
        ...     version=1,
        ...     prompt_template="Extract entity: <<input>>",
        ... )
        >>> # Using DatasetConfig and ModelConfig objects:
        >>> dataset = DatasetConfig(table="catalog.schema.vendor_eval", primary_key="row_id")
        >>> model_config = ModelConfig(name="gpt-4o-test", litellm_model="gpt-4o")
        >>> result = run_evaluation(dataset=dataset, model_config=model_config)
        >>> # Using prompt registry:
        >>> result = run_evaluation(
        ...     eval_table="catalog.schema.vendor_eval_20260126",
        ...     primary_key="row_id",
        ...     endpoint="vendor-tagger-v1",
        ...     prompt_registry_name="mycatalog.myschema.vendor_tagging_prompt",
        ...     prompt_alias="production",
        ...     metrics=["latency", "token_count", "exact_match"],
        ... )
    """
    logger = get_logger()
    spark = ydbc.get_spark_session()

    # Avoid the mutable-default-argument pitfall: None means "use the default set".
    if metrics is None:
        metrics = ["latency", "token_count"]

    with AggregateExceptionHandler() as exc_handler:
        if dataset is None:
            dataset = DatasetConfig(
                table=eval_table,
                primary_key=primary_key,
                input_column=input_column,
                expected_output_column=expected_output_column,
                additional_context_columns=additional_context_columns,
            )

        if model_config is None:
            # ModelConfig requires a non-empty name. Prefer the caller-supplied
            # run_name; otherwise derive one from the model identity. When a model
            # identity exists we also promote the derived name to run_name so the
            # MLflow run and the config share the same label.
            if run_name is None:
                model_name_for_default = model_name or endpoint or litellm_model
                if model_name_for_default:
                    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
                    run_name = f"{model_name_for_default}_{timestamp}"
            config_name = run_name
            if config_name is None:
                # No model identity at all — name the config but leave run_name
                # unset (it is defaulted again after validation below).
                timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
                config_name = f"unknown_model_{timestamp}"
            model_config = ModelConfig(
                name=config_name,
                model_name=model_name,
                endpoint=endpoint,
                litellm_model=litellm_model,
                version=version,
                prompt_template=prompt_template,
                system_prompt=system_prompt,
                prompt_registry_name=prompt_registry_name,
                prompt_version=prompt_version,
                prompt_alias=prompt_alias,
                max_output_tokens=max_output_tokens,
                tags=tags,
                temperature=litellm_model_kwargs.get("temperature")
                if litellm_model_kwargs
                else None,
                litellm_model_kwargs=litellm_model_kwargs,
            )

        if "llm_judge" in metrics and llm_judge_config is None:
            exc_handler.collect_raise(
                ValueError,
                "'llm_judge_config' is required when 'llm_judge' metric is specified.",
            )
        exc_handler.raise_if_any()

    # Split user-specified metrics into built-in names and custom scorer callables.
    builtin_metrics = [
        m for m in metrics if isinstance(m, str) and m in BUILTIN_METRICS
    ]
    custom_scorers = [m for m in metrics if callable(m)]

    # Generate IDs and paths
    run_id = _generate_run_id()
    mlflow_exp_path = mlflow_experiment or _get_mlflow_experiment_path(dataset.table)

    # Defensive guard — model_config is always assigned above, but fail loudly
    # rather than with an AttributeError if that invariant is ever broken.
    if model_config is None:
        raise ValueError(
            "model_config is None. This should not happen - ModelConfig creation must have failed. "
            "Check that endpoint or litellm_model is provided, and name is valid."
        )

    model_name = model_config.endpoint or model_config.litellm_model
    if run_name is None:
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        run_name = f"{model_name}_{timestamp}"

    logger.info(f"Starting eval run: {run_id}")
    logger.info(f"Eval table: {dataset.table}")
    logger.info(f"Model: {model_name}")

    # Create snapshot of run configuration for logging. Values come from the
    # resolved dataset/model_config objects (not the flat params, which may be
    # None when the caller used config objects).
    run_snapshot = _EvalRunSnapshot(
        eval_table=dataset.table,
        primary_key=dataset.primary_key,
        endpoint=model_config.endpoint,
        litellm_model=model_config.litellm_model,
        prompt_template=model_config.prompt_template,
        system_prompt=model_config.system_prompt,
        metrics=builtin_metrics,
        llm_judge_config=llm_judge_config,
        batch_size=batch_size,
        max_concurrent=max_concurrent,
        timeout_seconds=timeout_seconds,
        max_output_tokens=model_config.max_output_tokens,
        input_column=dataset.input_column,
        expected_output_column=dataset.expected_output_column,
        additional_context_columns=dataset.additional_context_columns,
        litellm_model_kwargs=model_config.litellm_model_kwargs,
    )

    df = spark.table(dataset.table)
    row_count = df.count()
    logger.info(f"Loaded {row_count} rows from eval table")

    # Ensure workspace directory exists for MLflow experiment
    parent_dir = "/".join(mlflow_exp_path.split("/")[:-1])
    logger.info(f"MLflow experiment path: {mlflow_exp_path}, parent dir: {parent_dir}")
    create_workspace_directory(parent_dir)
    logger.info("Workspace directory created, setting MLflow experiment...")
    mlflow.set_experiment(mlflow_exp_path)
    logger.info("MLflow experiment set successfully")

    with mlflow.start_run(run_name=run_name, nested=is_nested) as mlflow_run:
        mlflow_run_id = mlflow_run.info.run_id
        mlflow.log_params(
            {
                "eval_table": dataset.table,
                "primary_key": dataset.primary_key,
                "endpoint": model_config.endpoint,
                "litellm_model": model_config.litellm_model,
                "metrics": str(builtin_metrics),
                "row_count": row_count,
            }
        )
        # Log config as artifact
        mlflow.log_dict(run_snapshot.to_dict(), "config.json")
        if tags:
            mlflow.set_tags(tags)
        try:
            # Run inference
            logger.info("Running inference...")
            df_with_output = run_inference(
                df=df,
                endpoint=model_config.endpoint,
                litellm_model=model_config.litellm_model,
                prompt_column=dataset.input_column,
                prompt_template=model_config.prompt_template,
                system_prompt=model_config.system_prompt,
                max_output_tokens=model_config.max_output_tokens,
                output_column="model_output",
                additional_context_columns=dataset.additional_context_columns,
                litellm_model_kwargs=model_config.litellm_model_kwargs,
                version=model_config.version,
            )

            # Compute metrics
            logger.info("Computing metrics...")
            df_with_metrics = _compute_row_metrics(
                df=df_with_output,
                metrics=builtin_metrics,
                expected_output_column=dataset.expected_output_column,
                llm_judge_config=llm_judge_config,
                custom_scorers=custom_scorers,
            )

            # Compute LLM judge if requested
            if "llm_judge" in builtin_metrics and llm_judge_config is not None:
                logger.info("Computing LLM judge scores...")
                df_with_metrics = _compute_llm_judge_metrics(
                    df=df_with_metrics,
                    config=llm_judge_config,
                    expected_output_column=dataset.expected_output_column,
                )

            # Add metadata columns
            df_with_metrics = df_with_metrics.withColumn(
                "created_at", F.current_timestamp()
            )

            # Drop internal columns that may contain NullType (not supported by Parquet)
            internal_columns_to_drop = ["_ai_result"]
            for col in internal_columns_to_drop:
                if col in df_with_metrics.columns:
                    df_with_metrics = df_with_metrics.drop(col)

            # Count errors
            error_count = df_with_metrics.filter(
                F.col("error_message").isNotNull()
            ).count()
            error_rate = error_count / row_count if row_count > 0 else 0

            # Check error threshold.
            # NOTE(review): the docstring says MLOpsToolkitEvalRunFailedException
            # is raised when the threshold is exceeded, but the current behavior
            # is log-only — confirm the intended contract before changing either.
            if error_rate > MAX_ERROR_RATE_THRESHOLD:
                logger.error(
                    f"Eval run failed: error rate {error_rate} exceeds threshold {MAX_ERROR_RATE_THRESHOLD}"
                )

            # Save results to a run-specific Delta table next to the eval table.
            table_parts = dataset.table.split(".")
            results_table_name = f"{table_parts[2]}_results_{run_id}"
            results_full_table = (
                f"{table_parts[0]}.{table_parts[1]}.{results_table_name}"
            )
            logger.info(f"Saving results to: {results_full_table}")
            with suppress_table_creation_logs():
                create_table(
                    schema_name=table_parts[1],
                    table_name=results_table_name,
                    query=df_with_metrics,
                    catalog_name=table_parts[0],
                    overwrite=False,
                )

            # Build metric columns from user-specified metrics.
            # latency_ms is always present from inference; other metrics have _score suffix.
            metric_columns = ["latency_ms"]
            for m in builtin_metrics:
                if m not in ("latency", "token_count"):
                    metric_columns.append(f"{m}_score")

            # Collect results for aggregation
            results_rows = df_with_metrics.select(*metric_columns).collect()
            results_dicts = [row.asDict() for row in results_rows]
            metrics_summary = aggregate_all_metrics(results_dicts, metric_columns)

            # Add error metrics
            metrics_summary["error_count"] = error_count
            metrics_summary["error_rate"] = error_rate

            # Log metrics to MLflow
            mlflow.log_metrics(metrics_summary)
            # Log results table reference
            mlflow.log_text(results_full_table, "results_table_ref.txt")

            # Register run
            _register_eval_run(
                run_id=run_id,
                eval_table=dataset.table,
                config=run_snapshot,
                results_table=results_full_table,
                mlflow_experiment=mlflow_exp_path,
                mlflow_run_id=mlflow_run_id,
                metrics_summary=metrics_summary,
                row_count=row_count,
                error_count=error_count,
                status=RunStatus.COMPLETED,
                run_name=run_name,
                experiment_id=_experiment_id,
                model_name=_model_name,
            )

            logger.info(f"Eval run completed: {run_id}")
            logger.info(f"Results table: {results_full_table}")
            logger.info(f"Metrics: {metrics_summary}")

            return EvalRunResult(
                run_id=run_id,
                mlflow_run_id=mlflow_run_id,
                results_table=results_full_table,
                metrics_summary=metrics_summary,
                row_count=row_count,
                error_count=error_count,
                created_at=datetime.now(),
            )
        except Exception as e:
            # Mark the MLflow run as failed before propagating.
            logger.error(f"Eval run failed: {e}")
            mlflow.log_param("status", RunStatus.FAILED)
            mlflow.log_param("error", str(e))
            raise
# [docs]  (Sphinx "view source" link artifact — not executable Python; kept as a comment)
def run_experiment(
    experiment: Experiment,
    *,
    mlflow_experiment_name: str | None = None,
    trigger_remote: bool = False,
    wait: bool = False,
) -> ExperimentResult | RemoteExperimentRun:
    """Run an experiment comparing multiple models against a single dataset.

    Creates a parent MLflow run with nested child runs for each model.
    All results are stored in individual tables plus a combined table
    with a 'model_name' column for easy comparison.

    Metrics are defined once at the experiment level and applied to all models,
    eliminating duplication. A model that fails is logged and skipped so the
    remaining models still run; the experiment status reflects this
    (COMPLETED / PARTIAL / FAILED).

    Args:
        experiment: Experiment object containing dataset, models, and metrics.
        mlflow_experiment_name: MLflow experiment name. If not specified, defaults
            to '/Evals/{schema_name}/{table_base_name}' derived from dataset table.
        trigger_remote: If True, submit the experiment as a remote Databricks
            serverless job instead of running locally. The experiment must not
            contain callable metrics. Defaults to False.
        wait: Only used when trigger_remote=True. If True, block and poll
            until the remote job completes, then return a full ExperimentResult.
            If False (default), return immediately with a RemoteExperimentRun.

    Returns:
        ExperimentResult if running locally or with trigger_remote=True and wait=True.
        RemoteExperimentRun if trigger_remote=True and wait=False (fire-and-forget).

    Raises:
        MLOpsToolkitTableNotFoundException: If dataset table doesn't exist.
        ValueError: If trigger_remote=True and experiment contains callable metrics.

    Example:
        >>> dataset = DatasetConfig(
        ...     table="catalog.schema.vendor_eval",
        ...     primary_key="row_id",
        ...     input_column="vendor_name",
        ...     expected_output_column="canonical_name",
        ... )
        >>> experiment = Experiment(
        ...     name="vendor-tagger-comparison",
        ...     dataset=dataset,
        ...     models=[
        ...         ModelConfig(name="llama-8b", endpoint="llama-endpoint"),
        ...         ModelConfig(name="gpt-4o", litellm_model="gpt-4o"),
        ...     ],
        ...     metrics=[Metric.LATENCY, Metric.EXACT_MATCH],
        ... )
        >>> result = run_experiment(experiment)
        >>> result.summary_df.orderBy("exact_match_accuracy", ascending=False).show()
        >>> # Run remotely (fire-and-forget):
        >>> remote_run = run_experiment(experiment, trigger_remote=True)
        >>> print(remote_run.databricks_url)
        >>> result = remote_run.get_result()  # blocks until done
    """
    # --- Remote execution path ---
    if trigger_remote:
        # Imported lazily so local-only usage doesn't require the remote helpers.
        from ml_toolkit.functions.eval_utils.helpers.remote import (
            _submit_remote_experiment,
            _validate_experiment_serializable,
        )

        _validate_experiment_serializable(experiment)
        remote_run = _submit_remote_experiment(
            experiment=experiment,
            mlflow_experiment_name=mlflow_experiment_name,
        )
        if wait:
            return remote_run.get_result()
        return remote_run

    # --- Local execution path ---
    logger = get_logger()
    spark = ydbc.get_spark_session()
    dataset = experiment.dataset

    # Validate inputs
    with AggregateExceptionHandler() as exc_handler:
        if not table_exists(dataset.table):
            exc_handler.collect_raise(MLOpsToolkitTableNotFoundException, dataset.table)
        exc_handler.raise_if_any()

    # Generate IDs and paths.
    # NOTE(review): run_evaluation calls _get_mlflow_experiment_path(dataset.table)
    # while this passes the DatasetConfig object — confirm the helper accepts both.
    experiment_id = _generate_experiment_id()
    parent_run_id = _generate_parent_run_id()
    mlflow_exp_path = mlflow_experiment_name or _get_mlflow_experiment_path(dataset)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    parent_run_name = f"{experiment.name}_{timestamp}"
    dataset_hash = dataset.compute_dataset_hash(spark)
    row_count = spark.table(dataset.table).count()
    experiment_started_at = datetime.now()

    logger.info(f"Starting experiment: {experiment.name} ({experiment_id})")
    logger.info(f"Dataset: {dataset.table} (hash: {dataset_hash})")
    logger.info(f"Models: {[m.name for m in experiment.models]}")
    logger.info(f"Metrics: {[str(m) for m in experiment.metrics if not callable(m)]}")

    # Ensure workspace directory exists for MLflow experiment
    parent_dir = "/".join(mlflow_exp_path.split("/")[:-1])
    logger.info(f"MLflow experiment path: {mlflow_exp_path}, parent dir: {parent_dir}")
    create_workspace_directory(parent_dir)
    logger.info("Workspace directory created, setting MLflow experiment...")
    mlflow.set_experiment(mlflow_exp_path)
    logger.info("MLflow experiment set successfully")

    model_results: dict[str, EvalRunResult] = {}

    # Start parent run
    with mlflow.start_run(run_name=parent_run_name) as parent_run:
        parent_mlflow_run_id = parent_run.info.run_id
        mlflow.log_params(
            {
                "experiment_name": experiment.name,
                "dataset_table": dataset.table,
                "dataset_hash": dataset_hash,
                "num_models": len(experiment.models),
                "model_names": str([m.name for m in experiment.models]),
                "primary_key": dataset.primary_key,
            }
        )
        mlflow.log_dict(experiment.to_dict(), "experiment_info.json")
        # Save experiment config as YAML artifact for reproducibility
        yaml_content = yaml.dump(
            experiment.to_dict(), default_flow_style=False, sort_keys=False
        )
        mlflow.log_text(yaml_content, "experiment_config.yaml")
        mlflow.set_tags(
            {
                "eval_type": "experiment",
                "experiment_id": experiment_id,
                "experiment_name": experiment.name,
                "parent_run_id": parent_run_id,
                **(experiment.tags or {}),
            }
        )

        # Run each model as a nested run
        for model_cfg in experiment.models:
            logger.info(f"Running model: {model_cfg.name}")
            experiment_tags = {
                "model_name": model_cfg.name,
                "experiment_id": experiment_id,
                "experiment_name": experiment.name,
                "parent_run_id": parent_run_id,
                **(model_cfg.tags or {}),
            }
            try:
                if isinstance(model_cfg, LocalModelConfig):
                    # Local GPU model — run inference + eval locally
                    result = run_local_evaluation(
                        dataset=dataset,
                        local_model_config=model_cfg,
                        tags=experiment_tags,
                        metrics=experiment.metrics,
                        llm_judge_config=experiment.llm_judge_config,
                        is_nested=True,
                        _experiment_id=experiment_id,
                        _model_name=model_cfg.name,
                    )
                else:
                    # Endpoint/LiteLLM model — use existing run_evaluation
                    result = run_evaluation(
                        dataset=dataset,
                        model_config=model_cfg,
                        tags=experiment_tags,
                        metrics=experiment.metrics,
                        llm_judge_config=experiment.llm_judge_config,
                        is_nested=True,
                        _experiment_id=experiment_id,
                        _model_name=model_cfg.name,
                    )
                model_results[model_cfg.name] = result
            except Exception as e:
                # Skip the failed model so the remaining models still run; the
                # experiment status below becomes PARTIAL (or FAILED if none
                # succeeded). Previously this re-raised, which made those status
                # branches unreachable.
                logger.error(f"Model '{model_cfg.name}' failed: {e}")
                continue

        combined_results_table = _create_combined_results_table(
            results=model_results,
            dataset=dataset,
            parent_run_id=parent_run_id,
        )

        # Log summary metrics to parent run
        summary_metrics = {}
        for model_name, result in model_results.items():
            for metric_name, value in result.metrics_summary.items():
                summary_metrics[f"{model_name}_{metric_name}"] = value
        mlflow.log_metrics(summary_metrics)
        mlflow.log_text(combined_results_table, "combined_results_table_ref.txt")

        # Determine experiment status
        experiment_completed_at = datetime.now()
        experiment_duration = (
            experiment_completed_at - experiment_started_at
        ).total_seconds()
        if len(model_results) == len(experiment.models):
            experiment_status = ExperimentStatus.COMPLETED
        elif len(model_results) > 0:
            experiment_status = ExperimentStatus.PARTIAL
        else:
            experiment_status = ExperimentStatus.FAILED

        # Get metrics list (filter out callables)
        metrics_enabled = [str(m) for m in experiment.metrics if not callable(m)]

        # Register experiment in the registry
        _register_experiment(
            experiment_id=experiment_id,
            experiment_name=experiment.name,
            eval_table=dataset.table,
            dataset_hash=dataset_hash,
            config=experiment.to_dict(),
            metrics_enabled=metrics_enabled,
            mlflow_experiment=mlflow_exp_path,
            parent_mlflow_run_id=parent_mlflow_run_id,
            results_table=combined_results_table,
            num_models=len(experiment.models),
            row_count=row_count,
            status=experiment_status,
            created_at=experiment_started_at,
            completed_at=experiment_completed_at,
            duration_seconds=experiment_duration,
            description=experiment.description,
            tags=experiment.tags,
        )

        logger.info(f"Experiment completed: {experiment.name} ({experiment_id})")
        logger.info(f"Combined results: {combined_results_table}")

        return ExperimentResult(
            experiment_id=experiment_id,
            experiment_name=experiment.name,
            parent_run_id=parent_mlflow_run_id,
            mlflow_experiment=mlflow_exp_path,
            dataset_hash=dataset_hash,
            results_table=combined_results_table,
            model_results=model_results,
        )
def run_local_evaluation(
eval_table: str | None = None,
*,
primary_key: str | None = None,
model: Any = None,
tokenizer: Any = None,
prompt_column: str = "prompt",
expected_output_column: str | None = "expected_output",
batch_size: int = 16,
max_seq_len: int = 2048,
max_new_tokens: int = 64,
do_sample: bool = False,
temperature: float | None = None,
top_p: float | None = None,
seed: int | None = None,
output_parser: Callable[[str], list[str]] | None = None,
model_name: str | None = None,
metrics: list[str | Callable] | None = None,
llm_judge_config: LLMJudgeConfig | None = None,
mlflow_experiment: str | None = None,
run_name: str | None = None,
tags: dict[str, str] | None = None,
# Config objects (override flat params when provided)
dataset: DatasetConfig | None = None,
local_model_config: LocalModelConfig | None = None,
is_nested: bool = False,
_experiment_id: str | None = None,
_model_name: str | None = None,
) -> EvalRunResult:
"""Execute a local GPU evaluation run.
Runs inference on all rows in the eval table using a local PyTorch model
via ``run_pandas_batch_inference``, computes specified metrics using the
shared metric pipeline, and logs results to MLflow.
Parameters can be provided directly or via DatasetConfig/LocalModelConfig
objects. When both are provided, config objects take precedence.
Args:
eval_table: Fully qualified eval table name (catalog.schema.table).
Can be provided via dataset object instead.
primary_key: Name of the primary key column. Can be provided via
dataset object.
model: PyTorch model with ``.generate()`` (e.g. HuggingFace
``AutoModelForCausalLM``). Can be provided via local_config.
tokenizer: Corresponding tokenizer. Can be provided via local_config.
prompt_column: Column containing prompt text. Defaults to 'prompt'.
Can be provided via local_config.
expected_output_column: Name of the expected output column for
comparison metrics. Can be provided via dataset object.
batch_size: Number of prompts per generation batch.
max_seq_len: Maximum input sequence length for tokenization.
max_new_tokens: Maximum new tokens to generate.
do_sample: Whether to use sampling (True) or greedy (False).
temperature: Sampling temperature (only used when ``do_sample=True``).
top_p: Nucleus sampling probability (only used when ``do_sample=True``).
seed: Random seed for reproducibility.
output_parser: Callable that parses raw output string to entity list.
model_name: Model name for logging.
metrics: List of metrics to compute. Defaults to entity extraction
metrics. Can include built-in metric names or custom scorer
callables.
llm_judge_config: Configuration for LLM-as-judge metric.
mlflow_experiment: MLflow experiment path.
run_name: Optional name for the MLflow run.
tags: Optional tags to apply to the MLflow run.
dataset: DatasetConfig object containing table reference and column
configuration. Overrides flat params when provided.
local_config: LocalModelConfig object containing model and generation
configuration. Overrides flat params when provided.
is_nested: Whether this is a nested MLflow run.
Returns:
EvalRunResult containing run_id, mlflow_run_id, results_table,
metrics_summary, row_count, and error_count.
Example:
>>> # Using flat parameters:
>>> result = run_local_evaluation(
... eval_table="catalog.schema.eval_data",
... primary_key="id",
... model=model,
... tokenizer=tokenizer,
... prompt_column="prompt",
... expected_output_column="entities_gold",
... metrics=["entity_f1", "entity_exact_match"],
... )
>>> # Using config objects:
>>> dataset = DatasetConfig(
... table="catalog.schema.eval_data",
... primary_key="id",
... expected_output_column="entities_gold",
... )
>>> local_config = LocalModelConfig(
... name="finetuned-llama",
... model=model,
... tokenizer=tokenizer,
... max_new_tokens=128,
... )
>>> result = run_local_evaluation(
... dataset=dataset,
... local_config=local_config,
... metrics=["entity_f1", "entity_precision"],
... )
"""
from ml_toolkit.ml.llm.inference import run_pandas_batch_inference
logger = get_logger()
spark = ydbc.get_spark_session()
# --- Build config objects from flat params if not provided ---
if dataset is None:
dataset = DatasetConfig(
table=eval_table,
primary_key=primary_key,
expected_output_column=expected_output_column,
)
if local_model_config is None:
config_name = model_name or run_name or "local_model"
local_model_config = LocalModelConfig(
name=config_name,
model=model,
tokenizer=tokenizer,
prompt_column=prompt_column,
batch_size=batch_size,
max_seq_len=max_seq_len,
max_new_tokens=max_new_tokens,
do_sample=do_sample,
temperature=temperature,
top_p=top_p,
seed=seed,
output_parser=output_parser,
tags=tags,
)
# Use config name for model_name if not explicitly provided
effective_model_name = _model_name or model_name or local_model_config.name
# Default metrics: entity extraction
if metrics is None:
metrics = [
Metric.ENTITY_PRECISION,
Metric.ENTITY_RECALL,
Metric.ENTITY_F1,
Metric.ENTITY_JACCARD,
Metric.ENTITY_EXACT_MATCH,
Metric.ENTITY_EXACT_MATCH_CI,
Metric.ENTITY_RELAXED_MATCH,
]
builtin_metrics = [
str(m) for m in metrics if isinstance(m, str) and str(m) in BUILTIN_METRICS
]
custom_scorers = [m for m in metrics if callable(m)]
# Generate IDs and paths
run_id = _generate_run_id()
mlflow_exp_path = mlflow_experiment or _get_mlflow_experiment_path(dataset)
if run_name is None:
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
run_name = f"{effective_model_name}_{timestamp}"
logger.info(f"Starting local eval run: {run_id}")
logger.info(f"Eval table: {dataset.table}")
logger.info(f"Model: {effective_model_name}")
# --- Load eval data and run inference ---
eval_df = spark.table(dataset.table)
row_count = eval_df.count()
logger.info(f"Loaded {row_count} rows from eval table")
logger.info("Running local GPU inference...")
predictions_df = run_pandas_batch_inference(
model=local_model_config.model,
tokenizer=local_model_config.tokenizer,
eval_df=eval_df,
prompt_column=local_model_config.prompt_column,
gold_column=dataset.expected_output_column,
id_column=dataset.primary_key,
batch_size=local_model_config.batch_size,
max_seq_len=local_model_config.max_seq_len,
max_new_tokens=local_model_config.max_new_tokens,
do_sample=local_model_config.do_sample,
temperature=local_model_config.temperature,
top_p=local_model_config.top_p,
seed=local_model_config.seed,
output_parser=local_model_config.output_parser,
)
# Add model_output column for text-based metric compatibility
# (raw_output is the JSON string representation of predictions)
if "raw_output" in predictions_df.columns:
predictions_df = predictions_df.withColumn("model_output", F.col("raw_output"))
# --- Compute metrics using shared pipeline ---
logger.info("Computing metrics...")
df_with_metrics = _compute_row_metrics(
df=predictions_df,
metrics=builtin_metrics,
expected_output_column="entities_gold",
llm_judge_config=llm_judge_config,
custom_scorers=custom_scorers,
)
# Compute LLM judge if requested
if "llm_judge" in builtin_metrics and llm_judge_config is not None:
logger.info("Computing LLM judge scores...")
df_with_metrics = _compute_llm_judge_metrics(
df=df_with_metrics,
config=llm_judge_config,
expected_output_column="entities_gold",
)
# Add metadata
df_with_metrics = df_with_metrics.withColumn("created_at", F.current_timestamp())
# --- Aggregate metrics ---
# Build metric columns list for aggregation
metric_columns = []
for m in builtin_metrics:
if m.startswith("entity_"):
metric_columns.append(f"{m}_score")
elif m not in ("latency", "token_count"):
metric_columns.append(f"{m}_score")
# Add entity TP/FP/FN columns if any entity metrics requested
if any(m.startswith("entity_") for m in builtin_metrics):
metric_columns.extend(["entity_tp", "entity_fp", "entity_fn"])
# Collect and aggregate
if metric_columns:
available_columns = [c for c in metric_columns if c in df_with_metrics.columns]
if available_columns:
results_rows = df_with_metrics.select(*available_columns).collect()
results_dicts = [row.asDict() for row in results_rows]
metrics_summary = aggregate_all_metrics(results_dicts, available_columns)
else:
metrics_summary = {}
else:
metrics_summary = {}
logger.info(f"Metrics: {metrics_summary}")
# --- MLflow logging ---
parent_dir = "/".join(mlflow_exp_path.split("/")[:-1])
create_workspace_directory(parent_dir)
mlflow.set_experiment(mlflow_exp_path)
with mlflow.start_run(run_name=run_name, nested=is_nested) as mlflow_run:
mlflow_run_id = mlflow_run.info.run_id
mlflow.log_params(
{
"eval_table": dataset.table,
"primary_key": dataset.primary_key,
"model_name": effective_model_name,
"eval_type": "local",
"metrics": str(builtin_metrics),
"row_count": row_count,
"batch_size": local_model_config.batch_size,
"max_new_tokens": local_model_config.max_new_tokens,
"do_sample": local_model_config.do_sample,
}
)
# Log config as artifact
mlflow.log_dict(local_model_config.to_dict(), "config.json")
mlflow.log_metrics(metrics_summary)
merged_tags = {
"eval_type": "local",
**(local_model_config.tags or {}),
**(tags or {}),
}
mlflow.set_tags(merged_tags)
# --- Persist results ---
table_parts = dataset.table.split(".")
results_table_name = f"{table_parts[2]}_results_{run_id}"
results_full_table = f"{table_parts[0]}.{table_parts[1]}.{results_table_name}"
logger.info(f"Saving results to: {results_full_table}")
with suppress_table_creation_logs():
create_table(
schema_name=table_parts[1],
table_name=results_table_name,
query=df_with_metrics,
catalog_name=table_parts[0],
overwrite=False,
)
# Register run
run_snapshot = _EvalRunSnapshot(
eval_table=dataset.table,
primary_key=dataset.primary_key,
endpoint=None,
litellm_model=None,
prompt_template=None,
system_prompt=None,
metrics=builtin_metrics,
llm_judge_config=llm_judge_config,
max_output_tokens=local_model_config.max_new_tokens,
input_column=local_model_config.prompt_column,
expected_output_column=dataset.expected_output_column,
)
_register_eval_run(
run_id=run_id,
eval_table=dataset.table,
config=run_snapshot,
results_table=results_full_table,
mlflow_experiment=mlflow_exp_path,
mlflow_run_id=mlflow_run_id,
metrics_summary=metrics_summary,
row_count=row_count,
error_count=0,
status=RunStatus.COMPLETED,
run_name=run_name,
experiment_id=_experiment_id,
model_name=_model_name,
)
logger.info(f"Local eval run completed: {run_id}")
logger.info(f"Results table: {results_full_table}")
return EvalRunResult(
run_id=run_id,
mlflow_run_id=mlflow_run_id,
results_table=results_full_table,
metrics_summary=metrics_summary,
row_count=row_count,
error_count=0,
created_at=datetime.now(),
)
def run_local_variance_analysis(
    local_model_config: LocalModelConfig,
    dataset: DatasetConfig,
    *,
    baseline_result: EvalRunResult,
    metrics: list[str | Callable] | None = None,
    n_reruns: int = 10,
    base_seed: int = 12345,
    temperature: float = 0.2,
    top_p: float = 0.9,
    mlflow_experiment: str | None = None,
    tags: dict[str, str] | None = None,
) -> dict:
    """Run sampling-based variance analysis for a local model.

    Reruns inference ``n_reruns`` times with sampling enabled (a different
    random seed per rerun) and computes variance statistics across runs to
    measure model stability. Each rerun creates a modified
    ``LocalModelConfig`` with sampling overrides and delegates to
    ``run_local_evaluation``.

    Args:
        local_model_config: LocalModelConfig with model, tokenizer, and base
            generation params.
        dataset: DatasetConfig with table reference and column configuration.
        baseline_result: Result from greedy (run_idx=0) evaluation.
        metrics: List of metrics to compute. Defaults to entity extraction
            metrics.
        n_reruns: Number of sampling reruns.
        base_seed: Base random seed (actual seed = base_seed + k for rerun k).
        temperature: Sampling temperature for variance runs.
        top_p: Nucleus sampling probability for variance runs.
        mlflow_experiment: MLflow experiment path.
        tags: Tags for MLflow runs.

    Returns:
        Dict with keys:
        - ``per_run_metrics``: List of dicts, one per run (including
          baseline at run_idx=0)
        - ``summary``: Dict with mean, std, SEM for each metric across
          sampling runs (empty when ``n_reruns`` is 0)
        - ``baseline_metrics``: Metrics from the greedy baseline

    Example:
        >>> local_model_config = LocalModelConfig(
        ...     name="my-model", model=model, tokenizer=tokenizer
        ... )
        >>> dataset = DatasetConfig(
        ...     table="catalog.schema.eval", primary_key="id",
        ...     expected_output_column="entities_gold",
        ... )
        >>> # First run greedy baseline
        >>> baseline = run_local_evaluation(
        ...     dataset=dataset, local_model_config=local_model_config
        ... )
        >>> # Then run variance analysis
        >>> analysis = run_local_variance_analysis(
        ...     local_model_config=local_model_config, dataset=dataset,
        ...     baseline_result=baseline, n_reruns=10,
        ... )
    """
    from dataclasses import replace
    import numpy as np

    logger = get_logger()
    logger.info(
        f"Starting variance analysis: {n_reruns} reruns, "
        f"temp={temperature}, top_p={top_p}, base_seed={base_seed}"
    )
    # Per-run metrics list; the greedy baseline occupies run_idx=0 with no
    # decode seed (it was not a sampling run).
    per_run_metrics = [
        {
            "run_idx": 0,
            "decode_seed": None,
            **{k: float(v) for k, v in baseline_result.metrics_summary.items()},
        }
    ]
    for k in range(1, n_reruns + 1):
        seed_k = base_seed + k
        logger.info(f" Rerun {k}/{n_reruns}, seed={seed_k}")
        # Copy the base config with sampling overrides for this rerun; the
        # caller's config object is never mutated.
        sampling_model_config = replace(
            local_model_config,
            do_sample=True,
            temperature=temperature,
            top_p=top_p,
            seed=seed_k,
        )
        result_k = run_local_evaluation(
            dataset=dataset,
            local_model_config=sampling_model_config,
            metrics=metrics,
            mlflow_experiment=mlflow_experiment,
            run_name=f"{local_model_config.name}_run{k}_seed{seed_k}",
            tags=tags,
            is_nested=True,
        )
        per_run_metrics.append(
            {
                "run_idx": k,
                "decode_seed": seed_k,
                # k_ avoids shadowing the rerun index k above.
                **{k_: float(v) for k_, v in result_k.metrics_summary.items()},
            }
        )
    # Compute variance statistics across sampling runs only (baseline excluded:
    # greedy decoding contributes no sampling variance).
    sampling_runs = [m for m in per_run_metrics if m["run_idx"] >= 1]
    summary: dict[str, dict[str, float]] = {}
    # Guard: with n_reruns=0 there are no sampling runs and indexing
    # sampling_runs[0] below would raise IndexError.
    if sampling_runs:
        metric_keys = [
            key
            for key, value in sampling_runs[0].items()
            if key not in ("run_idx", "decode_seed") and isinstance(value, float)
        ]
        for key in metric_keys:
            values = [m[key] for m in sampling_runs if m.get(key) is not None]
            if not values:
                continue
            arr = np.array(values)
            # Sample std (ddof=1) needs at least two observations.
            std = float(arr.std(ddof=1)) if len(arr) > 1 else 0.0
            summary[key] = {
                "mean": float(arr.mean()),
                "std": std,
                "sem": std / float(np.sqrt(len(arr))) if len(arr) > 1 else 0.0,
                "min": float(arr.min()),
                "max": float(arr.max()),
            }
    logger.info("Variance analysis complete")
    for key, stats in summary.items():
        if "f1" in key or "accuracy" in key:
            logger.info(f" {key}: {stats['mean']:.4f} +/- {stats['sem']:.4f} (SEM)")
    return {
        "per_run_metrics": per_run_metrics,
        "summary": summary,
        "baseline_metrics": baseline_result.metrics_summary,
    }