# Source code for ml_toolkit.functions.eval_utils.helpers.multi_model_comparison

"""Multi-model comparison helpers.

Provides functions for pivoting multi-model UDTF output and computing
pairwise / majority agreement metrics (strict and relaxed) across models.

Usage:
    >>> from ml_toolkit.functions.eval_utils import (
    ...     pivot_and_compare_results,
    ...     add_comparison_metrics,
    ...     normalized_output_array,
    ... )
"""

from itertools import combinations
import re

from pyspark.sql import DataFrame
from pyspark.sql import functions as F
from pyspark.sql import types as T

# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

_RE_LEADING_FENCE = re.compile(r"^```\w*\s*\n?")
_RE_TRAILING_FENCE = re.compile(r"\n?```\s*$")


def _clean_markdown(text: str) -> str:
    """Strip markdown code fences from a string."""
    if not text:
        return text
    text = _RE_LEADING_FENCE.sub("", text.strip())
    text = _RE_TRAILING_FENCE.sub("", text.strip())
    return text.strip()


# UDF wrapper so _clean_markdown can be applied to DataFrame string columns.
_clean_markdown_udf = F.udf(_clean_markdown, T.StringType())

# Schema for parsing a single model-output JSON object of the form
# {"model": "<model name>", "output": "<string, itself JSON per the callers>"}.
# Both fields are nullable strings; "output" is re-parsed downstream.
_MODEL_OUTPUT_SCHEMA = T.StructType(
    [
        T.StructField("model", T.StringType(), True),
        T.StructField("output", T.StringType(), True),
    ]
)


def normalized_output_array(colname: str):
    """Build a Spark Column that normalizes a model-output JSON string to
    ``array<string>``.

    Expects *colname* to contain JSON of the form
    ``{"model": "...", "output": "[\\"a\\",\\"b\\"]"}``. The returned
    expression:

    1. Parses *colname* as ``struct<model, output>``.
    2. Parses the ``output`` field as ``array<string>``.
    3. Applies ``lower(trim(x))`` to each element.
    4. Filters out null and empty strings.
    5. Returns ``array_distinct(...)`` of the result.

    :param colname: Name of the column containing the JSON string.
    :returns: A PySpark Column expression.
    """
    # Outer parse gives the wrapper struct; inner parse decodes the
    # string-encoded array held in its "output" field.
    record = F.from_json(F.expr(colname), _MODEL_OUTPUT_SCHEMA)
    raw_items = F.from_json(record["output"], "array<string>")
    # Case-fold and trim each element, drop blanks, then de-duplicate.
    folded = F.transform(raw_items, lambda item: F.lower(F.trim(item)))
    kept = F.filter(folded, lambda item: item.isNotNull() & (item != ""))
    return F.array_distinct(kept)
# ---------------------------------------------------------------------------
# Comparison metrics
# ---------------------------------------------------------------------------
def add_comparison_metrics(df: DataFrame, models: list[str]) -> DataFrame:
    """Add a ``metrics`` map column with pairwise and majority agreement scores.

    Expects *df* to have columns ``m0``, ``m1``, ... ``m{n-1}`` where each is
    an ``array<string>`` of normalized outputs (e.g. from
    :func:`normalized_output_array`). The returned DataFrame has only the
    original columns plus a single ``metrics`` column of type
    ``map<string, double>`` with keys:

    - ``all_models_strict``: 1.0 if every pairwise comparison is an exact
      set match.
    - ``all_models_relaxed``: 1.0 if every pairwise comparison satisfies
      subset-or-superset.
    - ``majority_strict``: 1.0 if a strict majority of outputs are identical.
    - ``majority_relaxed``: 1.0 if a relaxed majority of outputs agree.

    With fewer than two models all four metric values are null.

    :param df: DataFrame with ``m0`` .. ``m{n-1}`` array columns.
    :param models: List of model names (only ``len(models)`` matters).
    :returns: DataFrame with original columns + ``metrics``.
    """
    n = len(models)
    original_cols = df.columns

    if n < 2:
        # Not enough models to compare — return null metrics
        df = df.withColumn(
            "metrics",
            F.create_map(
                F.lit("all_models_strict"), F.lit(None).cast("double"),
                F.lit("all_models_relaxed"), F.lit(None).cast("double"),
                F.lit("majority_strict"), F.lit(None).cast("double"),
                F.lit("majority_relaxed"), F.lit(None).cast("double"),
            ),
        )
        return df.select(*original_cols, "metrics")

    # --- Pairwise strict: exact set equality (order-insensitive) ---
    for i in range(n):
        df = df.withColumn(f"m{i}_sorted", F.array_sort(F.col(f"m{i}")))

    strict_cols: list[str] = []
    relaxed_cols: list[str] = []
    for i, j in combinations(range(n), 2):
        strict_col = f"m{i}{j}_strict"
        df = df.withColumn(
            strict_col,
            (F.col(f"m{i}_sorted") == F.col(f"m{j}_sorted")).cast("double"),
        )
        strict_cols.append(strict_col)

        # Relaxed: one output set is a subset of the other
        minus_ij = f"m{i}_minus_m{j}"
        minus_ji = f"m{j}_minus_m{i}"
        relaxed_col = f"m{i}{j}_relaxed"
        df = (
            df.withColumn(minus_ij, F.array_except(F.col(f"m{i}"), F.col(f"m{j}")))
            .withColumn(minus_ji, F.array_except(F.col(f"m{j}"), F.col(f"m{i}")))
            .withColumn(
                relaxed_col,
                (
                    (F.size(F.col(minus_ij)) == 0)
                    | (F.size(F.col(minus_ji)) == 0)
                ).cast("double"),
            )
        )
        relaxed_cols.append(relaxed_col)

    # The pair columns were just built one per (i, j) pair, so their count
    # IS the number of pairs — no need to re-enumerate combinations.
    num_pairs = len(strict_cols)

    # --- All-models agreement: every pair must agree ---
    all_strict_sum = sum(F.col(c) for c in strict_cols)
    all_relaxed_sum = sum(F.col(c) for c in relaxed_cols)
    df = df.withColumn(
        "all_models_strict",
        F.when(all_strict_sum == num_pairs, F.lit(1.0)).otherwise(F.lit(0.0)),
    )
    df = df.withColumn(
        "all_models_relaxed",
        F.when(all_relaxed_sum == num_pairs, F.lit(1.0)).otherwise(F.lit(0.0)),
    )

    # --- Majority agreement ---
    # NOTE(review): despite the name, nulls are NOT filtered out here — a
    # null m{i} array stays in as a null element. Confirm this is intended.
    non_null_outputs = F.array(*[F.col(f"m{i}") for i in range(n)])
    df = df.withColumn("_non_null_outputs", non_null_outputs)

    # Strict-majority threshold: floor(n/2) + 1, i.e. strictly more than
    # half of the n outputs. (The previous comment claimed ceil(n/2), which
    # is a lower bar for even n and did not match this expression.)
    majority_threshold = (n // 2) + 1

    # Candidate answers: the distinct sorted output arrays
    df = df.withColumn(
        "_candidates",
        F.array_distinct(
            F.transform(F.col("_non_null_outputs"), lambda x: F.array_sort(x))
        ),
    )

    # Strict majority: count how many outputs match each candidate
    # (exact sorted equality)
    df = df.withColumn(
        "_strict_counts",
        F.transform(
            F.col("_candidates"),
            lambda candidate: F.aggregate(
                F.col("_non_null_outputs"),
                F.lit(0),
                lambda acc, x: acc
                + F.when(F.array_sort(x) == candidate, F.lit(1)).otherwise(F.lit(0)),
            ),
        ),
    )
    df = df.withColumn(
        "majority_strict",
        F.when(
            F.array_max(F.col("_strict_counts")) >= majority_threshold, F.lit(1.0)
        ).otherwise(F.lit(0.0)),
    )

    # Relaxed majority: for each candidate, count how many outputs are
    # subset-or-superset of that candidate
    df = df.withColumn(
        "_relaxed_counts",
        F.transform(
            F.col("_candidates"),
            lambda candidate: F.aggregate(
                F.col("_non_null_outputs"),
                F.lit(0),
                lambda acc, x: acc
                + F.when(
                    (F.size(F.array_except(x, candidate)) == 0)
                    | (F.size(F.array_except(candidate, x)) == 0),
                    F.lit(1),
                ).otherwise(F.lit(0)),
            ),
        ),
    )
    df = df.withColumn(
        "majority_relaxed",
        F.when(
            F.array_max(F.col("_relaxed_counts")) >= majority_threshold, F.lit(1.0)
        ).otherwise(F.lit(0.0)),
    )

    # --- Assemble metrics map; the final select drops all scratch columns ---
    df = df.withColumn(
        "metrics",
        F.create_map(
            F.lit("all_models_strict"), F.col("all_models_strict"),
            F.lit("all_models_relaxed"), F.col("all_models_relaxed"),
            F.lit("majority_strict"), F.col("majority_strict"),
            F.lit("majority_relaxed"), F.col("majority_relaxed"),
        ),
    )
    return df.select(*original_cols, "metrics")
# ---------------------------------------------------------------------------
# Pivot and compare
# ---------------------------------------------------------------------------
def pivot_and_compare_results(results_df: DataFrame, models: list[str]) -> DataFrame:
    """Pivot multi-model UDTF output and add comparison metrics.

    Takes the long-format output from the multi-model UDTF (one row per
    model per input) and produces a wide-format DataFrame with one row per
    input, one column per model, and a ``metrics`` map.

    :param results_df: DataFrame with columns ``model``, ``output``,
        ``parameters``, ``error``, ``raw_output`` (as produced by the
        multi-model inference UDTF).
    :param models: Ordered list of model names (must match the values in
        the ``model`` column).
    :returns: DataFrame with ``row_id``, ``parameters``,
        ``model_0`` .. ``model_{n-1}``, and ``metrics``.
    """
    n = len(models)

    # 1. Clean output (safety net for cached results)
    df = results_df.withColumn("output", _clean_markdown_udf(F.col("output")))

    # 2. Parse parameters; row_id comes from input_uuid when present,
    #    otherwise from an md5 of the raw parameters string.
    df = df.withColumn(
        "_params_parsed",
        F.from_json(F.col("parameters"), "map<string,string>"),
    )
    row_id_expr = F.coalesce(
        F.col("_params_parsed").getItem("input_uuid"),
        F.md5(F.col("parameters")),
    )
    df = df.withColumn("row_id", row_id_expr)

    # 3. One row per input: collect (model, output) pairs into a map
    pairs = F.collect_list(F.struct(F.col("model"), F.col("output")))
    df = df.groupBy("row_id", "parameters").agg(
        F.map_from_entries(pairs).alias("model_outputs")
    )

    # 4. One JSON-string column per model ({"model": ..., "output": ...})
    for idx, model_name in enumerate(models):
        wrapped = F.struct(
            F.lit(model_name).alias("model"),
            F.col("model_outputs").getItem(model_name).alias("output"),
        )
        df = df.withColumn(f"model_{idx}", F.to_json(wrapped))

    # 5. Normalize each model column to array<string>
    for idx in range(n):
        df = df.withColumn(f"m{idx}", normalized_output_array(f"model_{idx}"))

    # 6. Add comparison metrics
    df = add_comparison_metrics(df, models)

    # 7. Final projection: row_id, parameters, model_0..model_{n-1}, metrics
    output_cols = ["row_id", "parameters"]
    output_cols.extend(f"model_{idx}" for idx in range(n))
    output_cols.append("metrics")
    return df.select(*output_cols)