"""Configuration dataclasses for run_experiment().
This module provides typed configuration classes for running experiments
that compare multiple model configurations against a single dataset.
"""
from __future__ import annotations
from dataclasses import dataclass, field
import hashlib
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable
if TYPE_CHECKING:
import yipit_databricks_client as ydbc
from ml_toolkit.functions.eval_utils.constants import DEFAULT_MAX_OUTPUT_TOKENS, Metric
from ml_toolkit.functions.eval_utils.helpers.types import LLMJudgeConfig, YAMLMixin
from ml_toolkit.ops.helpers.logger import get_logger
def _convert_template_syntax(template: str) -> str:
"""Convert template syntax from {{variable}} (MLflow registry) to <<variable>> (eval_utils).
Args:
template: Template string with {{variable}} syntax.
Returns:
Template string with <<variable>> syntax.
"""
import re
# Replace {{variable}} with <<variable>>
# Handle both {{variable}} and {{ variable }} (with spaces)
pattern = r"\{\{\s*(\w+)\s*\}\}"
return re.sub(pattern, r"<<\1>>", template)
def _normalize_template_whitespace(template: str) -> str:
"""Normalize whitespace in a template: collapse newlines to single space, trim.
Prompts loaded from the registry often contain literal newlines (e.g. from
multi-line strings). This normalizes them so the template behaves like a
single-line direct prompt unless the user explicitly keeps newlines.
Args:
template: Template string (may contain \\n, \\r\\n).
Returns:
Template with newlines replaced by space and multiple spaces collapsed.
"""
import re
# Replace any newline (CR, LF, CRLF) with a space
out = template.replace("\r\n", " ").replace("\n", " ").replace("\r", " ")
# Collapse multiple spaces into one
out = re.sub(r" +", " ", out)
return out.strip()
def _load_prompt_from_registry(
    prompt_registry_name: str,
    prompt_version: str | int | None = None,
    prompt_alias: str | None = None,
) -> tuple[str | None, str | None]:
    """Load prompt template and system prompt from MLflow Prompt Registry.

    Args:
        prompt_registry_name: Fully qualified prompt name (catalog.schema.prompt_name)
            or URI (prompts:/catalog.schema.prompt_name@alias).
        prompt_version: Specific version to load. Ignored if prompt_registry_name
            includes version/alias.
        prompt_alias: Alias to load (e.g., 'production'). Ignored if
            prompt_registry_name includes alias.

    Returns:
        Tuple of (prompt_template, system_prompt). Both may be None if not found.

    Raises:
        ImportError: If prompt registry module is not available.
        ValueError: If prompt cannot be loaded.
    """
    logger = get_logger()
    try:
        from ml_toolkit.functions.llm.prompt.function import load_prompt_from_mlflow
    except ImportError as import_err:
        raise ImportError(
            "Prompt registry integration requires ml_toolkit.functions.llm.prompt module. "
            "Ensure it is available in your environment."
        ) from import_err

    # Resolve the URI to load. Version/alias is always encoded in the URI
    # itself (never passed as a separate parameter) so MLflow does not fall
    # back to a non-existent "latest" alias.
    if prompt_registry_name.startswith("prompts:/"):
        # Already a full URI — use as-is; MLflow handles any embedded
        # version/alias.
        prompt_uri = prompt_registry_name
    elif prompt_alias:
        # Alias in the URI takes precedence over any prompt_version.
        prompt_uri = f"prompts:/{prompt_registry_name}@{prompt_alias}"
    elif prompt_version is not None:
        # URI format with version: prompts:/catalog.schema.name/version.
        # Normalize numeric strings (e.g. "2") to a canonical int form; fall
        # back to str() for anything non-numeric.
        try:
            version_str = (
                str(int(prompt_version))
                if isinstance(prompt_version, (str, int))
                else str(prompt_version)
            )
        except (ValueError, TypeError):
            version_str = str(prompt_version)
        prompt_uri = f"prompts:/{prompt_registry_name}/{version_str}"
    else:
        # No version or alias specified — load latest version by name only.
        prompt_uri = prompt_registry_name

    logger.info(f"Loading prompt from registry: {prompt_uri}")
    try:
        # Load without a version parameter — version/alias is in the URI if
        # needed. This prevents MLflow from trying a non-existent "latest"
        # alias.
        prompt_version_obj = load_prompt_from_mlflow(
            name_or_uri=prompt_uri,
            allow_missing=False,
        )
        if prompt_version_obj is None:
            raise ValueError(f"Prompt '{prompt_registry_name}' not found in registry")
        template = prompt_version_obj.template
        # Convert template syntax from {{variable}} to <<variable>>
        converted_template = _convert_template_syntax(template)
        # Normalize newlines/whitespace so loaded prompts match direct-prompt style
        converted_template = _normalize_template_whitespace(converted_template)
        # MLflow Prompt Registry doesn't store system prompts separately,
        # so system_prompt will be None
        system_prompt = None
        logger.info(
            f"Successfully loaded prompt '{prompt_registry_name}' "
            f"(version {prompt_version_obj.version})"
        )
        return converted_template, system_prompt
    except Exception as e:
        error_msg = str(e)
        # Check if error is about "latest" alias not found
        if "latest" in error_msg.lower() and "not found" in error_msg.lower():
            raise ValueError(
                f"Failed to load prompt '{prompt_registry_name}' from registry: {error_msg}\n"
                f"Hint: When no version or alias is specified, MLflow tries to use a 'latest' alias.\n"
                f"Either:\n"
                f"  1. Specify a version: prompt_version=1\n"
                f"  2. Specify an alias: prompt_alias='production'\n"
                f"  3. Use URI format with alias: prompt_registry_name='prompts:/{prompt_registry_name}@production'\n"
                f"  4. Create a 'latest' alias for your prompt in the registry"
            ) from e
        else:
            raise ValueError(
                f"Failed to load prompt '{prompt_registry_name}' from registry: {error_msg}"
            ) from e
@dataclass
class DatasetConfig(YAMLMixin):
    """Reference to an evaluation dataset.

    Encapsulates all information needed to locate and use an evaluation
    dataset for running evaluations.

    Args:
        table: Fully qualified table name (catalog.schema.table).
        primary_key: Name of the primary key column.
        input_column: Name of the input text column. Defaults to 'input'.
        expected_output_column: Name of the expected output column for
            comparison metrics. Set to None if not available.
        additional_context_columns: Additional columns to include in prompt
            template context.
        description: Optional description of the dataset.
        tags: Optional tags for the dataset.
        dataset_yaml: Optional path to YAML file to load config from.
            If provided, all other parameters are ignored.

    Example:
        >>> dataset = DatasetConfig(
        ...     table="catalog.schema.vendor_eval_20260126",
        ...     primary_key="row_id",
        ...     input_column="vendor_name",
        ...     expected_output_column="canonical_name",
        ... )
        # Or load from YAML:
        >>> dataset = DatasetConfig(dataset_yaml="path/to/dataset.yaml")
    """

    # Required (have defaults to allow dataset_yaml-only construction)
    table: str = ""
    primary_key: str = ""
    # Column configuration
    input_column: str = "input"
    expected_output_column: str | None = "expected_output"
    additional_context_columns: list[str] | None = None
    # Optional metadata
    description: str | None = None
    tags: dict[str, str] | None = None
    # YAML loading
    dataset_yaml: str | Path | None = field(default=None, repr=False)
    # Cached hash value (compare=False so two configs with identical fields
    # compare equal whether or not the hash has been computed)
    _cached_hash: str | None = field(default=None, repr=False, compare=False)

    def __post_init__(self) -> None:
        """Load from YAML if dataset_yaml is provided, then validate."""
        if self.dataset_yaml is not None:
            # Copy every field from the YAML-loaded instance onto self, then
            # clear dataset_yaml so the loaded config is self-contained.
            loaded = self.from_yaml(self.dataset_yaml)
            object.__setattr__(self, "table", loaded.table)
            object.__setattr__(self, "primary_key", loaded.primary_key)
            object.__setattr__(self, "input_column", loaded.input_column)
            object.__setattr__(
                self, "expected_output_column", loaded.expected_output_column
            )
            object.__setattr__(
                self, "additional_context_columns", loaded.additional_context_columns
            )
            object.__setattr__(self, "description", loaded.description)
            object.__setattr__(self, "tags", loaded.tags)
            object.__setattr__(self, "dataset_yaml", None)
        # Validate required fields
        if not self.table:
            raise ValueError("DatasetConfig 'table' is required")
        if not self.primary_key:
            raise ValueError("DatasetConfig 'primary_key' is required")

    def compute_dataset_hash(self, spark: "ydbc.SparkSession | None" = None) -> str:
        """Compute MD5 hash of dataset for reproducibility tracking.

        The hash is computed from:
        - Table name
        - Row count
        - Sample of primary key values

        This provides a reproducible fingerprint of the dataset that changes
        if the dataset content changes. (MD5 is used purely as a fingerprint,
        not for security.)

        Args:
            spark: Optional SparkSession. If not provided, will get from ydbc.

        Returns:
            MD5 hash string (32 characters).
        """
        if self._cached_hash is not None:
            return self._cached_hash
        import yipit_databricks_client as ydbc

        if spark is None:
            spark = ydbc.get_spark_session()
        # Get table info for hashing
        df = spark.table(self.table)
        row_count = df.count()
        # Get sample of primary keys for content fingerprinting.
        # Use first 100 PKs sorted to ensure consistency across runs.
        pk_sample = (
            df.select(self.primary_key).orderBy(self.primary_key).limit(100).collect()
        )
        pk_values = [str(row[self.primary_key]) for row in pk_sample]
        # Build hash input
        hash_input = f"{self.table}:{row_count}:{','.join(pk_values)}"
        hash_value = hashlib.md5(hash_input.encode()).hexdigest()
        # Cache the result
        object.__setattr__(self, "_cached_hash", hash_value)
        return hash_value

    @property
    def dataset_hash(self) -> str:
        """Get MD5 hash of dataset for reproducibility tracking.

        Note: This property computes the hash on first access if not cached.
        For better control over SparkSession, use compute_dataset_hash() directly.

        Returns:
            MD5 hash string (32 characters).
        """
        return self.compute_dataset_hash()

    def to_dict(self) -> dict[str, Any]:
        """Convert dataset reference to dictionary for serialization.

        Returns:
            Dict representation of the dataset reference.
        """
        return {
            "table": self.table,
            "primary_key": self.primary_key,
            "input_column": self.input_column,
            "expected_output_column": self.expected_output_column,
            "additional_context_columns": self.additional_context_columns,
            "description": self.description,
            "tags": self.tags,
        }

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "DatasetConfig":
        """Create DatasetConfig from a dictionary.

        Args:
            data: Dictionary with config values.

        Returns:
            DatasetConfig instance.

        Raises:
            KeyError: If required fields (table, primary_key) are missing.
        """
        return cls(
            table=data["table"],
            primary_key=data["primary_key"],
            input_column=data.get("input_column", "input"),
            expected_output_column=data.get(
                "expected_output_column", "expected_output"
            ),
            additional_context_columns=data.get("additional_context_columns"),
            description=data.get("description"),
            tags=data.get("tags"),
        )
@dataclass
class ModelConfig(YAMLMixin):
    """Configuration for a single model/prompt to evaluate.

    Represents one model/prompt configuration to evaluate against a dataset.
    Use with `Experiment` to compare multiple configurations.
    Validation runs automatically on construction via __post_init__.

    Args:
        name: Unique name for this config (used in MLflow run name and results).
        model_name: Unity Catalog model name (e.g., 'catalog.schema.model_name').
            If provided without endpoint, the endpoint name will be automatically
            inferred from the model name. Either model_name/endpoint or litellm_model
            must be specified.
        endpoint: Databricks Model Serving endpoint name. Optional if model_name
            is provided (will be inferred). Either model_name/endpoint or litellm_model
            must be specified (not both).
        litellm_model: LiteLLM model identifier (e.g., 'gpt-4', 'claude-3').
            Either model_name/endpoint or litellm_model must be specified (not both).
        version: Optional integer version for routing to a specific
            model version in multi-version serving endpoints. Only used when
            endpoint is specified. Example: 1, 2, 3. Defaults to None (latest version).
        prompt_template: Jinja2 template for constructing the user message.
            Available variables: <<input>>, <<candidates>>, and context columns.
            Use <<variable>> syntax. If prompt_registry_name is provided,
            this will be overridden by the loaded template.
        system_prompt: Optional system prompt for chat completion.
            If prompt_registry_name is provided, this will be overridden.
        prompt_registry_name: Fully qualified prompt name from MLflow Prompt Registry
            (e.g., 'catalog.schema.prompt_name') or URI (e.g., 'prompts:/catalog.schema.prompt_name@alias').
            If provided, loads prompt_template from registry (converts {{variable}} to <<variable>>).
        prompt_version: Specific version to load from registry. Ignored if prompt_registry_name
            includes version/alias or if prompt_alias is provided.
        prompt_alias: Alias to load from registry (e.g., 'production'). Takes precedence over prompt_version.
        max_output_tokens: Maximum output tokens for model response.
        temperature: Temperature for model sampling (None uses model default).
        tags: Additional tags to apply to the MLflow run for this config.
        litellm_model_kwargs: Optional extra keyword arguments forwarded for
            LiteLLM models.
        model_yaml: Optional path to YAML file to load config from.
            If provided, all other parameters are ignored.

    Example:
        >>> from ml_toolkit.functions.eval_utils import ModelConfig
        >>> # Using model_name (endpoint automatically inferred)
        >>> config = ModelConfig(
        ...     name="llama-8b-prompt-v1",
        ...     model_name="catalog.schema.meta-llama-3-1-8b-instruct",
        ...     prompt_template="Tag the vendor: <<input>>",
        ... )
        >>> # Using endpoint directly
        >>> config = ModelConfig(
        ...     name="llama-8b-prompt-v1",
        ...     endpoint="databricks-meta-llama-3-1-8b-instruct",
        ...     prompt_template="Tag the vendor: <<input>>",
        ... )
        >>> # Load prompt from MLflow Prompt Registry
        >>> config = ModelConfig(
        ...     name="llama-8b-registry",
        ...     model_name="catalog.schema.meta-llama-3-1-8b-instruct",
        ...     prompt_registry_name="mycatalog.myschema.vendor_tagging_prompt",
        ...     prompt_alias="production",  # or prompt_version=2
        ... )
        # Or load from YAML:
        >>> config = ModelConfig(model_yaml="path/to/model.yaml")

    Raises:
        ValueError: If validation fails (missing endpoint/model or both specified).
    """

    # Required (has default to allow model_yaml-only construction)
    name: str = ""
    # Model/Endpoint (one required)
    # Unity Catalog model name - used to infer endpoint if not provided
    model_name: str | None = None
    # Optional - inferred from model_name if not provided
    endpoint: str | None = None
    litellm_model: str | None = None
    version: int | None = None
    # Prompt configuration
    prompt_template: str | None = None
    system_prompt: str | None = None
    # Prompt registry integration
    prompt_registry_name: str | None = None
    prompt_version: str | int | None = None
    prompt_alias: str | None = None
    # Inference parameters
    max_output_tokens: int = DEFAULT_MAX_OUTPUT_TOKENS
    temperature: float | None = None
    # Optional metadata
    tags: dict[str, str] | None = None
    litellm_model_kwargs: dict[str, Any] | None = None
    # YAML loading
    model_yaml: str | Path | None = field(default=None, repr=False)

    def __post_init__(self) -> None:
        """Load from YAML if model_yaml is provided, then load prompt from registry if needed, then validate."""
        if self.model_yaml is not None:
            loaded = self.from_yaml(self.model_yaml)
            object.__setattr__(self, "name", loaded.name)
            object.__setattr__(self, "model_name", loaded.model_name)
            object.__setattr__(self, "endpoint", loaded.endpoint)
            object.__setattr__(self, "litellm_model", loaded.litellm_model)
            object.__setattr__(self, "version", loaded.version)
            object.__setattr__(self, "prompt_template", loaded.prompt_template)
            object.__setattr__(self, "system_prompt", loaded.system_prompt)
            object.__setattr__(
                self, "prompt_registry_name", loaded.prompt_registry_name
            )
            object.__setattr__(self, "prompt_version", loaded.prompt_version)
            object.__setattr__(self, "prompt_alias", loaded.prompt_alias)
            object.__setattr__(self, "max_output_tokens", loaded.max_output_tokens)
            object.__setattr__(self, "temperature", loaded.temperature)
            object.__setattr__(self, "tags", loaded.tags)
            # BUGFIX: previously this field was not copied from the YAML-loaded
            # instance, silently dropping any litellm_model_kwargs in the YAML.
            object.__setattr__(
                self, "litellm_model_kwargs", loaded.litellm_model_kwargs
            )
            object.__setattr__(self, "model_yaml", None)
        # Infer endpoint from model_name if endpoint is not provided
        if self.model_name is not None and self.endpoint is None:
            from ml_toolkit.functions.llm.serving_endpoints.helpers.validation import (
                infer_endpoint_name_from_model,
            )

            inferred_endpoint = infer_endpoint_name_from_model(self.model_name)
            object.__setattr__(self, "endpoint", inferred_endpoint)
            logger = get_logger()
            logger.info(
                f"Inferred endpoint '{inferred_endpoint}' from model_name '{self.model_name}' for config '{self.name}'"
            )
        # Resolve version to latest if not specified and endpoint is provided
        if self.endpoint is not None and self.version is None:
            from ml_toolkit.functions.llm.serving_endpoints import (
                get_latest_version_from_endpoint,
            )

            try:
                latest_version = get_latest_version_from_endpoint(self.endpoint)
                object.__setattr__(self, "version", latest_version)
            except Exception:
                # If we can't resolve the version (e.g., endpoint doesn't exist yet),
                # leave it as None and let the evaluation function handle it
                pass
        # Store original prompt_template value before loading from registry.
        # This is needed for validation to check if user provided both initially.
        original_prompt_template = self.prompt_template
        # Load prompt from registry if prompt_registry_name is provided
        if self.prompt_registry_name:
            logger = get_logger()
            logger.info(
                f"Loading prompt from registry for config '{self.name}': {self.prompt_registry_name}"
            )
            # Ensure prompt_version is properly formatted (convert to int if string)
            prompt_version_to_use = self.prompt_version
            if prompt_version_to_use is not None:
                # Convert string version to int if needed (MLflow accepts both)
                try:
                    if isinstance(prompt_version_to_use, str):
                        prompt_version_to_use = int(prompt_version_to_use)
                except (ValueError, TypeError):
                    # If conversion fails, keep original value
                    pass
            template, system_prompt = _load_prompt_from_registry(
                prompt_registry_name=self.prompt_registry_name,
                prompt_version=prompt_version_to_use,
                prompt_alias=self.prompt_alias,
            )
            object.__setattr__(self, "prompt_template", template)
            if system_prompt is not None:
                object.__setattr__(self, "system_prompt", system_prompt)
        # Validate with original prompt_template to check if user provided both initially
        self._validate(original_prompt_template=original_prompt_template)

    def _validate(self, original_prompt_template: str | None = None) -> None:
        """Internal validation logic.

        Args:
            original_prompt_template: The prompt_template value before loading from registry.
                Used to check if user provided both direct prompt and registry initially.
        """
        if not self.name:
            raise ValueError("Config 'name' is required")
        # At this point, endpoint should have been inferred from model_name if provided.
        # Check that we have either endpoint (with or without model_name) or litellm_model.
        if not self.endpoint and not self.litellm_model:
            raise ValueError(
                f"Config '{self.name}': either 'model_name', 'endpoint', or 'litellm_model' is required"
            )
        if self.endpoint and self.litellm_model:
            raise ValueError(
                f"Config '{self.name}': only one of 'endpoint' or 'litellm_model' "
                "should be specified, not both"
            )
        # Validate prompt configuration: either direct prompts OR registry, not both.
        # Use original_prompt_template to check if user provided both initially
        # (after loading from registry, prompt_template will be set, so we need original value).
        # Only check original_prompt_template - don't fall back to self.prompt_template
        # because that will be set after loading from registry.
        has_direct_prompt = original_prompt_template is not None
        has_registry_prompt = self.prompt_registry_name is not None
        if has_direct_prompt and has_registry_prompt:
            raise ValueError(
                f"Config '{self.name}': cannot specify both 'prompt_template' and "
                "'prompt_registry_name'. Use either direct prompt or registry, not both."
            )
        # Validate that prompt_registry_name is provided if prompt_alias or prompt_version is used
        if (
            self.prompt_alias is not None or self.prompt_version is not None
        ) and not self.prompt_registry_name:
            raise ValueError(
                f"Config '{self.name}': 'prompt_registry_name' is required when "
                "'prompt_alias' or 'prompt_version' is specified."
            )
        if self.prompt_registry_name and self.prompt_version and self.prompt_alias:
            raise ValueError(
                f"Config '{self.name}': cannot specify both 'prompt_version' and "
                "'prompt_alias'. Use one or the other."
            )

    def to_dict(self) -> dict[str, Any]:
        """Convert config to dictionary for serialization.

        Returns:
            Dict representation of the config.
        """
        return {
            "name": self.name,
            "model_name": self.model_name,
            "endpoint": self.endpoint,
            "litellm_model": self.litellm_model,
            "version": self.version,
            "prompt_template": self.prompt_template,
            "system_prompt": self.system_prompt,
            "prompt_registry_name": self.prompt_registry_name,
            "prompt_version": self.prompt_version,
            "prompt_alias": self.prompt_alias,
            "max_output_tokens": self.max_output_tokens,
            "temperature": self.temperature,
            "tags": self.tags,
            # Included so to_dict/from_dict round-trips preserve the field
            # (previously dropped).
            "litellm_model_kwargs": self.litellm_model_kwargs,
        }

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "ModelConfig":
        """Create ModelConfig from a dictionary.

        Args:
            data: Dictionary with config values.

        Returns:
            ModelConfig instance.

        Raises:
            KeyError: If required field (name) is missing.
            ValueError: If validation fails.
        """
        return cls(
            name=data["name"],
            model_name=data.get("model_name"),
            endpoint=data.get("endpoint"),
            litellm_model=data.get("litellm_model"),
            version=data.get("version"),
            prompt_template=data.get("prompt_template"),
            system_prompt=data.get("system_prompt"),
            prompt_registry_name=data.get("prompt_registry_name"),
            prompt_version=data.get("prompt_version"),
            prompt_alias=data.get("prompt_alias"),
            max_output_tokens=data.get("max_output_tokens", DEFAULT_MAX_OUTPUT_TOKENS),
            temperature=data.get("temperature"),
            tags=data.get("tags"),
            litellm_model_kwargs=data.get("litellm_model_kwargs"),
        )
@dataclass
class Experiment(YAMLMixin):
    """Bundles a dataset, model configurations, and metrics together.

    An Experiment defines a complete evaluation scenario: what dataset to use,
    which models to compare, and which metrics to compute. Metrics are defined
    once and applied to ALL models, eliminating duplication.

    Args:
        name: Unique name for this experiment.
        dataset: DatasetConfig reference to the evaluation dataset.
        models: List of ModelConfig objects to evaluate.
        metrics: List of metrics to compute for all models. Can include Metric
            enum values, string names, or custom scorer callables.
        llm_judge_config: Configuration for LLM-as-judge metric. Required
            if 'llm_judge' is in metrics list.
        description: Optional description of the experiment.
        tags: Optional tags for the experiment.
        experiment_yaml: Optional path to YAML file to load config from.
            If provided, all other parameters are ignored.

    Example:
        >>> from ml_toolkit.functions.eval_utils import (
        ...     DatasetConfig, ModelConfig, Experiment, Metric, LLMJudgeConfig
        ... )
        >>> dataset = DatasetConfig(
        ...     table="catalog.schema.vendor_eval",
        ...     primary_key="row_id",
        ...     input_column="vendor_name",
        ...     expected_output_column="canonical_name",
        ... )
        >>> experiment = Experiment(
        ...     name="vendor-tagger-comparison",
        ...     dataset=dataset,
        ...     models=[
        ...         ModelConfig(name="llama-8b", endpoint="llama-endpoint"),
        ...         ModelConfig(name="gpt-4o", litellm_model="gpt-4o"),
        ...     ],
        ...     metrics=[Metric.LATENCY, Metric.EXACT_MATCH, Metric.LLM_JUDGE],
        ...     llm_judge_config=LLMJudgeConfig(criteria="correctness"),
        ... )
        # Or load from YAML:
        >>> experiment = Experiment(experiment_yaml="path/to/config.yaml")

    Raises:
        ValueError: If validation fails (no models, duplicate names, or
            llm_judge metric without config).
    """

    # Required (have defaults to allow experiment_yaml-only construction)
    name: str = ""
    # Annotated as Optional because the default is None; _validate() rejects
    # a missing dataset at construction time.
    dataset: DatasetConfig | None = None
    models: list[ModelConfig] = field(default_factory=list)
    # Metrics applied to ALL models
    metrics: list[str | Metric | Callable] = field(
        default_factory=lambda: [Metric.LATENCY, Metric.EXACT_MATCH]
    )
    llm_judge_config: LLMJudgeConfig | None = None
    # Metadata
    description: str | None = None
    tags: dict[str, str] | None = None
    # YAML loading
    experiment_yaml: str | Path | None = field(default=None, repr=False)

    def __post_init__(self) -> None:
        """Load from YAML if experiment_yaml is provided, then validate."""
        if self.experiment_yaml is not None:
            loaded = self.from_yaml(self.experiment_yaml)
            object.__setattr__(self, "name", loaded.name)
            object.__setattr__(self, "dataset", loaded.dataset)
            object.__setattr__(self, "models", loaded.models)
            object.__setattr__(self, "metrics", loaded.metrics)
            object.__setattr__(self, "llm_judge_config", loaded.llm_judge_config)
            object.__setattr__(self, "description", loaded.description)
            object.__setattr__(self, "tags", loaded.tags)
            object.__setattr__(self, "experiment_yaml", None)
        self._validate()

    def _validate(self) -> None:
        """Internal validation logic."""
        if not self.name:
            raise ValueError("Experiment 'name' is required")
        if self.dataset is None:
            raise ValueError(f"Experiment '{self.name}': 'dataset' is required")
        if not self.models:
            raise ValueError(
                f"Experiment '{self.name}': at least one ModelConfig must be provided"
            )
        # Check for duplicate model names
        names = [m.name for m in self.models]
        if len(names) != len(set(names)):
            duplicates = [n for n in names if names.count(n) > 1]
            raise ValueError(
                f"Experiment '{self.name}': model names must be unique. "
                f"Duplicates found: {set(duplicates)}"
            )
        # Check for llm_judge metric (handle both string and Metric enum)
        has_llm_judge = any(
            str(m) == Metric.LLM_JUDGE.value for m in self.metrics if not callable(m)
        )
        if has_llm_judge and self.llm_judge_config is None:
            raise ValueError(
                f"Experiment '{self.name}': 'llm_judge_config' is required when "
                "'llm_judge' metric is specified"
            )

    def to_dict(self) -> dict[str, Any]:
        """Convert experiment to dictionary for serialization.

        Returns:
            Dict representation of the experiment (excludes callable metrics).
        """
        # Convert metrics to strings (handles both str and Metric enum).
        # Filter out callable metrics which can't be serialized.
        serializable_metrics = [str(m) for m in self.metrics if not callable(m)]
        result = {
            "name": self.name,
            "dataset": self.dataset.to_dict(),
            "models": [m.to_dict() for m in self.models],
            "metrics": serializable_metrics,
            "description": self.description,
            "tags": self.tags,
        }
        if self.llm_judge_config:
            result["llm_judge_config"] = self.llm_judge_config.to_dict()
        return result

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "Experiment":
        """Create Experiment from a dictionary.

        Parses nested dataset, models, and llm_judge_config from dicts.

        Args:
            data: Dictionary with config values.

        Returns:
            Experiment instance.

        Raises:
            KeyError: If required fields are missing.
            ValueError: If validation fails.
        """
        # Parse nested dataset config
        dataset = DatasetConfig.from_dict(data["dataset"])
        # Parse nested model configs
        models = [ModelConfig.from_dict(m) for m in data["models"]]
        # Parse optional llm_judge_config
        llm_judge_config = None
        if data.get("llm_judge_config"):
            llm_judge_config = LLMJudgeConfig.from_dict(data["llm_judge_config"])
        # Metrics work as strings since Metric inherits from str
        metrics = data.get("metrics", [Metric.LATENCY, Metric.EXACT_MATCH])
        return cls(
            name=data["name"],
            dataset=dataset,
            models=models,
            metrics=metrics,
            llm_judge_config=llm_judge_config,
            description=data.get("description"),
            tags=data.get("tags"),
        )