"""Multi-model comparison helpers.
Provides functions for pivoting multi-model UDTF output and computing
pairwise / majority agreement metrics (strict and relaxed) across models.
Usage:

>>> from ml_toolkit.functions.eval_utils import (
...     pivot_and_compare_results,
...     add_comparison_metrics,
...     normalized_output_array,
... )
"""
from itertools import combinations
import re
from pyspark.sql import DataFrame
from pyspark.sql import functions as F
from pyspark.sql import types as T
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
# Opening markdown code fence at the start of a string: ``` plus an optional
# language tag (e.g. "```json") and trailing whitespace/newline.
_RE_LEADING_FENCE = re.compile(r"^```\w*\s*\n?")
# Closing markdown code fence at the end of a string.
_RE_TRAILING_FENCE = re.compile(r"\n?```\s*$")
def _clean_markdown(text: str) -> str:
"""Strip markdown code fences from a string."""
if not text:
return text
text = _RE_LEADING_FENCE.sub("", text.strip())
text = _RE_TRAILING_FENCE.sub("", text.strip())
return text.strip()
# Spark UDF wrapper so the fence-stripping cleaner can be applied to a column.
_clean_markdown_udf = F.udf(_clean_markdown, T.StringType())
# Schema for parsing a single model-output JSON object
_MODEL_OUTPUT_SCHEMA = T.StructType(
    [
        # Model name, as emitted by the multi-model UDTF.
        T.StructField("model", T.StringType(), True),
        # Raw output string; expected to itself be a JSON-encoded array<string>
        # (see normalized_output_array).
        T.StructField("output", T.StringType(), True),
    ]
)
def normalized_output_array(colname: str):
    """Build a Spark Column that normalizes a model-output JSON string to ``array<string>``.

    Expects *colname* to contain JSON of the form
    ``{"model": "...", "output": "[\"a\",\"b\"]"}``. The function:

    1. Parses *colname* as ``struct<model, output>``.
    2. Parses the ``output`` field as ``array<string>``.
    3. Applies ``lower(trim(x))`` to each element.
    4. Filters out null and empty strings.
    5. Returns ``array_distinct(...)`` of the result.

    :param colname: Name of the column containing the JSON string.
    :returns: A PySpark Column expression.
    """
    # NOTE: colname is interpolated directly into a SQL expression — it must be
    # a trusted, valid column name (never user-supplied input).
    return F.expr(
        f"""
        array_distinct(
            filter(
                transform(
                    from_json(
                        from_json({colname}, '{_MODEL_OUTPUT_SCHEMA.simpleString()}').output,
                        'array<string>'
                    ),
                    x -> lower(trim(x))
                ),
                x -> x is not null and x != ''
            )
        )
        """
    )
# ---------------------------------------------------------------------------
# Comparison metrics
# ---------------------------------------------------------------------------
def add_comparison_metrics(df: DataFrame, models: list[str]) -> DataFrame:
    """Add a ``metrics`` map column with pairwise and majority agreement scores.

    Expects *df* to have columns ``m0``, ``m1``, ... ``m{n-1}`` where each is
    an ``array<string>`` of normalized outputs (e.g. from
    :func:`normalized_output_array`).

    The returned DataFrame has only the original columns plus a single
    ``metrics`` column of type ``map<string, double>`` with keys:

    - ``all_models_strict``: 1.0 if every pairwise comparison is an exact set match.
    - ``all_models_relaxed``: 1.0 if every pairwise comparison satisfies subset-or-superset.
    - ``majority_strict``: 1.0 if a strict majority of outputs are identical.
    - ``majority_relaxed``: 1.0 if a relaxed majority of outputs agree.

    :param df: DataFrame with ``m0`` .. ``m{n-1}`` array columns.
    :param models: List of model names (only ``len(models)`` matters).
    :returns: DataFrame with original columns + ``metrics``.
    """
    n = len(models)
    original_cols = df.columns

    if n < 2:
        # Not enough models to compare — return a metrics map with null values.
        null_metric = F.lit(None).cast("double")
        df = df.withColumn(
            "metrics",
            F.create_map(
                F.lit("all_models_strict"),
                null_metric,
                F.lit("all_models_relaxed"),
                null_metric,
                F.lit("majority_strict"),
                null_metric,
                F.lit("majority_relaxed"),
                null_metric,
            ),
        )
        return df.select(*original_cols, "metrics")

    # Compute the pair index list once; it is reused for the per-pair columns
    # and for the pair count (the original recomputed combinations() twice).
    pairs = list(combinations(range(n), 2))
    num_pairs = len(pairs)

    # --- Pairwise strict: exact set equality (order-insensitive) ---
    for i in range(n):
        df = df.withColumn(f"m{i}_sorted", F.array_sort(F.col(f"m{i}")))
    strict_cols = []
    relaxed_cols = []
    for i, j in pairs:
        strict_col = f"m{i}{j}_strict"
        df = df.withColumn(
            strict_col,
            (F.col(f"m{i}_sorted") == F.col(f"m{j}_sorted")).cast("double"),
        )
        strict_cols.append(strict_col)
        # Relaxed: one output set is a subset of the other (either direction).
        minus_ij = f"m{i}_minus_m{j}"
        minus_ji = f"m{j}_minus_m{i}"
        relaxed_col = f"m{i}{j}_relaxed"
        df = (
            df.withColumn(minus_ij, F.array_except(F.col(f"m{i}"), F.col(f"m{j}")))
            .withColumn(minus_ji, F.array_except(F.col(f"m{j}"), F.col(f"m{i}")))
            .withColumn(
                relaxed_col,
                ((F.size(F.col(minus_ij)) == 0) | (F.size(F.col(minus_ji)) == 0)).cast(
                    "double"
                ),
            )
        )
        relaxed_cols.append(relaxed_col)

    # --- All-models agreement: every pair must agree ---
    all_strict_sum = sum(F.col(c) for c in strict_cols)
    all_relaxed_sum = sum(F.col(c) for c in relaxed_cols)
    df = df.withColumn(
        "all_models_strict",
        F.when(all_strict_sum == num_pairs, F.lit(1.0)).otherwise(F.lit(0.0)),
    )
    df = df.withColumn(
        "all_models_relaxed",
        F.when(all_relaxed_sum == num_pairs, F.lit(1.0)).otherwise(F.lit(0.0)),
    )

    # --- Majority agreement ---
    # Collect every model's output array into one array of arrays. Nulls are
    # NOT filtered out here (the original's "_non_null_outputs" name was
    # misleading): a null entry never equals a candidate, so in practice it
    # contributes 0 to the match counts below.
    df = df.withColumn("_all_outputs", F.array(*[F.col(f"m{i}") for i in range(n)]))
    # Strict-majority threshold: floor(n/2) + 1, i.e. more than half of n.
    # (The original comment claimed ceil(n/2), which is wrong for even n.)
    majority_threshold = (n // 2) + 1
    # Candidate answers: the distinct sorted output arrays observed in the row.
    df = df.withColumn(
        "_candidates",
        F.array_distinct(
            F.transform(F.col("_all_outputs"), lambda x: F.array_sort(x))
        ),
    )
    # Strict: for each candidate, count outputs whose sorted form matches exactly.
    df = df.withColumn(
        "_strict_counts",
        F.transform(
            F.col("_candidates"),
            lambda candidate: F.aggregate(
                F.col("_all_outputs"),
                F.lit(0),
                lambda acc, x: acc
                + F.when(F.array_sort(x) == candidate, F.lit(1)).otherwise(F.lit(0)),
            ),
        ),
    )
    df = df.withColumn(
        "majority_strict",
        F.when(
            F.array_max(F.col("_strict_counts")) >= majority_threshold, F.lit(1.0)
        ).otherwise(F.lit(0.0)),
    )
    # Relaxed: for each candidate, count outputs that are a subset or superset
    # of that candidate.
    df = df.withColumn(
        "_relaxed_counts",
        F.transform(
            F.col("_candidates"),
            lambda candidate: F.aggregate(
                F.col("_all_outputs"),
                F.lit(0),
                lambda acc, x: acc
                + F.when(
                    (F.size(F.array_except(x, candidate)) == 0)
                    | (F.size(F.array_except(candidate, x)) == 0),
                    F.lit(1),
                ).otherwise(F.lit(0)),
            ),
        ),
    )
    df = df.withColumn(
        "majority_relaxed",
        F.when(
            F.array_max(F.col("_relaxed_counts")) >= majority_threshold, F.lit(1.0)
        ).otherwise(F.lit(0.0)),
    )

    # --- Assemble metrics map; the final select drops every intermediate column ---
    df = df.withColumn(
        "metrics",
        F.create_map(
            F.lit("all_models_strict"),
            F.col("all_models_strict"),
            F.lit("all_models_relaxed"),
            F.col("all_models_relaxed"),
            F.lit("majority_strict"),
            F.col("majority_strict"),
            F.lit("majority_relaxed"),
            F.col("majority_relaxed"),
        ),
    )
    return df.select(*original_cols, "metrics")
# ---------------------------------------------------------------------------
# Pivot and compare
# ---------------------------------------------------------------------------
def pivot_and_compare_results(results_df: DataFrame, models: list[str]) -> DataFrame:
    """Pivot multi-model UDTF output and add comparison metrics.

    Takes the long-format output from the multi-model UDTF (one row per
    model per input) and produces a wide-format DataFrame with one row per
    input, one column per model, and a ``metrics`` map.

    :param results_df: DataFrame with columns ``model``, ``output``,
        ``parameters``, ``error``, ``raw_output`` (as produced by the
        multi-model inference UDTF).
    :param models: Ordered list of model names (must match the values in
        the ``model`` column).
    :returns: DataFrame with ``row_id``, ``parameters``,
        ``model_0`` .. ``model_{n-1}``, and ``metrics``.
    """
    # 1. Strip markdown fences from outputs (safety net for cached results).
    df = results_df.withColumn("output", _clean_markdown_udf(F.col("output")))

    # 2. Parse parameters and derive a stable row_id: prefer an explicit
    #    input_uuid parameter, falling back to a hash of the raw parameters.
    df = df.withColumn(
        "_params_parsed", F.from_json(F.col("parameters"), "map<string,string>")
    )
    df = df.withColumn(
        "row_id",
        F.coalesce(
            F.col("_params_parsed").getItem("input_uuid"), F.md5(F.col("parameters"))
        ),
    )

    # 3. Group by row_id, collecting model -> output entries into a map.
    #    NOTE(review): map_from_entries raises on duplicate keys by default,
    #    so each (row_id, model) pair is assumed unique — confirm upstream.
    df = df.groupBy("row_id", "parameters").agg(
        F.map_from_entries(
            F.collect_list(F.struct(F.col("model"), F.col("output")))
        ).alias("model_outputs")
    )

    # 4. Per-model columns (single pass; the original used two loops):
    #    model_{i} is a JSON string {"model", "output"}, and m{i} is its
    #    normalized array<string> used for comparison. A model missing from
    #    the map yields a null output and hence a null m{i}.
    for i, model in enumerate(models):
        df = df.withColumn(
            f"model_{i}",
            F.to_json(
                F.struct(
                    F.lit(model).alias("model"),
                    F.col("model_outputs").getItem(model).alias("output"),
                )
            ),
        )
        df = df.withColumn(f"m{i}", normalized_output_array(f"model_{i}"))

    # 5. Add comparison metrics, then keep only the public columns.
    df = add_comparison_metrics(df, models)
    final_cols = (
        ["row_id", "parameters"]
        + [f"model_{i}" for i in range(len(models))]
        + ["metrics"]
    )
    return df.select(*final_cols)