"""Public API for eval table_name management.
This module provides functions to create, update, and manage evaluation tables
in Unity Catalog.
"""
from pyspark.sql import DataFrame
import yipit_databricks_client as ydbc
from yipit_databricks_utils.helpers.delta import create_table
from ml_toolkit.functions.eval_utils.constants import EVAL_CATALOG
from ml_toolkit.functions.eval_utils.helpers.table_management import (
add_tags_to_table,
generate_timestamped_table_name,
get_base_table_name,
get_table_lock_info,
remove_table_lock,
set_table_lock,
validate_eval_dataframe,
validate_schema_name,
)
from ml_toolkit.ops.helpers.exceptions import (
AggregateExceptionHandler,
MLOpsToolkitEvalTableLockedException,
MLOpsToolkitSchemaNotFoundException,
)
from ml_toolkit.ops.helpers.logger import get_logger, suppress_table_creation_logs
from ml_toolkit.ops.helpers.validation import assert_table_exists, schema_exists
def create_eval_table(
    schema_name: str,
    table_name: str,
    primary_key: str,
    df: DataFrame,
    *,
    catalog_name: str = EVAL_CATALOG,
    additional_columns: dict[str, str] | None = None,
    description: str | None = None,
    tags: dict[str, str] | None = None,
) -> str:
    """Create a new evaluation dataset table in Unity Catalog.

    Creates a Delta table with the standard eval schema plus any additional
    columns. The table name will be suffixed with a datetime stamp. The schema
    name must end with '_dataset_bronze'.

    Args:
        schema_name: Schema name (must end with '_dataset_bronze').
        table_name: Base table name (datetime suffix will be appended).
        primary_key: Name of the primary key column in the DataFrame.
        df: Spark DataFrame containing the evaluation data. Must include
            'input' column and the specified primary key column.
        catalog_name: Unity Catalog name. Defaults to 'yd_tagging_platform_evals'.
        additional_columns: Optional dict mapping column names to SQL types
            for additional columns beyond the standard schema.
            NOTE(review): accepted but currently unused by this function —
            confirm whether it should be passed through to ``create_table``.
        description: Optional table description for UC metadata.
        tags: Optional dict of UC tags to apply to the table.

    Returns:
        Fully qualified table name: '{catalog_name}.{schema_name}.{table_name}_{timestamp}'

    Raises:
        MLOpsToolkitInvalidSchemaNameException: If schema_name doesn't end with '_dataset_bronze'.
        MLOpsToolkitColumnNotFoundException: If DataFrame is missing 'input' or primary key column.
        MLOpsToolkitDuplicatePrimaryKeyException: If primary key column contains duplicates.
        MLOpsToolkitSchemaNotFoundException: If schema_name doesn't exist.

    Example:
        >>> df = spark.createDataFrame([
        ...     {"row_id": "1", "input": "acme corp", "candidates": ["Acme", "ACME Corp"]},
        ...     {"row_id": "2", "input": "widgets inc", "candidates": ["Widgets", "Widgets Inc"]},
        ... ])
        >>> table_name = create_eval_table(
        ...     schema_name="vendor_tagging_dataset_bronze",
        ...     table_name="vendor_eval",
        ...     primary_key="row_id",
        ...     df=df,
        ...     description="Vendor name tagging evaluation set v1",
        ...     tags={"domain": "vendor", "version": "1.0"}
        ... )
        >>> print(table_name)
        'yd_tagging_platform_evals.vendor_tagging_dataset_bronze.vendor_eval_20260126_143052'
    """
    logger = get_logger()
    # Validate all inputs through the aggregate handler so the caller sees
    # every problem at once.  (A previous eager call to
    # validate_eval_dataframe here raised on the first failure and defeated
    # the aggregation — it has been removed.)
    with AggregateExceptionHandler() as exc_handler:
        exc_handler.collect(validate_schema_name, schema_name)
        exc_handler.collect(validate_eval_dataframe, df, primary_key)
        schema_full_name = f"{catalog_name}.{schema_name}"
        if not schema_exists(schema_full_name):
            exc_handler.collect_raise(
                MLOpsToolkitSchemaNotFoundException, schema_full_name
            )
        exc_handler.raise_if_any()
    # Timestamp suffix makes each created dataset version a distinct table.
    timestamped_name = generate_timestamped_table_name(table_name)
    full_table_name = f"{catalog_name}.{schema_name}.{timestamped_name}"
    logger.info(f"Creating eval table: {full_table_name}")
    # Create the Delta table; overwrite=False so an accidental name collision
    # fails instead of clobbering an existing dataset.
    with suppress_table_creation_logs():
        create_table(
            schema_name=schema_name,
            table_name=timestamped_name,
            query=df,
            catalog_name=catalog_name,
            overwrite=False,
            table_comment=description,
            spark_options={"mergeSchema": "true"},
        )
    add_tags_to_table(full_table_name, tags)
    logger.info(f"Created eval table: {full_table_name} with {df.count()} rows")
    return full_table_name
def upsert_eval_table(
    table_name: str,
    df: DataFrame,
    primary_key: str,
    *,
    create_new_version: bool = True,
) -> str:
    """Upsert rows into an existing evaluation table or create a new version.

    If create_new_version is True (default), creates a new timestamped table
    with the combined data. If False, performs an in-place MERGE operation
    on the existing table (requires the table to be unlocked).

    Args:
        table_name: Fully qualified table name (catalog.schema.table).
        df: Spark DataFrame containing rows to upsert. Must include the
            primary key column and 'input' column.
        primary_key: Name of the primary key column for merge matching.
        create_new_version: If True, create a new timestamped table with
            merged data. If False, update the existing table in place.

    Returns:
        Fully qualified table name (new name if versioned, same if in-place).

    Raises:
        MLOpsToolkitTableNotFoundException: If the table doesn't exist.
        MLOpsToolkitEvalTableLockedException: If the table is locked and create_new_version is False.
        MLOpsToolkitColumnNotFoundException: If the DataFrame schema doesn't match the existing table.

    Example:
        >>> new_rows = spark.createDataFrame([
        ...     {"row_id": "3", "input": "new vendor", "candidates": ["New Vendor"]}
        ... ])
        >>> new_table = upsert_eval_table(
        ...     table_name="yd_tagging_platform_evals.vendor_tagging_dataset_bronze.vendor_eval_20260126_143052",
        ...     df=new_rows,
        ...     primary_key="row_id",
        ...     create_new_version=True
        ... )
    """
    logger = get_logger()
    spark = ydbc.get_spark_session()
    assert_table_exists(table_name)
    validate_eval_dataframe(df, primary_key)
    if create_new_version:
        # Upsert semantics: keep existing rows whose primary key does NOT
        # appear in df (left_anti), then append all of df — so df wins on
        # primary-key collisions.
        existing_df = spark.table(table_name)
        merged_df = (
            existing_df.alias("existing")
            .join(df.alias("new"), on=primary_key, how="left_anti")
            .unionByName(df, allowMissingColumns=True)
        )
        # Split the qualified name; use a distinct local so the table_name
        # parameter is not shadowed mid-function.
        parts = table_name.split(".")
        catalog_name, schema_name, current_name = parts[0], parts[1], parts[2]
        # Strip the old timestamp suffix and append a fresh one for the new
        # versioned table.
        base_name = get_base_table_name(current_name)
        timestamped_name = generate_timestamped_table_name(base_name)
        new_full_name = f"{catalog_name}.{schema_name}.{timestamped_name}"
        with suppress_table_creation_logs():
            create_table(
                schema_name=schema_name,
                table_name=timestamped_name,
                query=merged_df,
                catalog_name=catalog_name,
                overwrite=False,
                spark_options={"mergeSchema": "true"},
            )
        logger.info(
            f"Created new version: {new_full_name} with {merged_df.count()} rows"
        )
        return new_full_name
    else:
        # In-place update - refuse if the table is locked.
        lock_info = get_table_lock_info(table_name)
        if lock_info is not None:
            raise MLOpsToolkitEvalTableLockedException(
                table_name,
                locked_by=lock_info.get("locked_by"),
                reason=lock_info.get("reason"),
            )
        # Perform MERGE via a temp view.  NOTE(review): table_name and
        # primary_key are interpolated into SQL; both are validated above
        # (assert_table_exists / validate_eval_dataframe), but confirm these
        # never carry unsanitized external input.
        df.createOrReplaceTempView("upsert_data")
        merge_sql = f"""
        MERGE INTO {table_name} AS target
        USING upsert_data AS source
        ON target.{primary_key} = source.{primary_key}
        WHEN MATCHED THEN UPDATE SET *
        WHEN NOT MATCHED THEN INSERT *
        """
        spark.sql(merge_sql)
        logger.info(f"Updated table in place: {table_name}")
        return table_name
def lock_eval_table(
    table_name: str,
    *,
    reason: str | None = None,
    locked_by: str | None = None,
) -> None:
    """Lock an evaluation table to prevent modifications.

    Applies a UC tag to mark the table as locked. Locked tables cannot be
    modified via upsert_eval_table (in-place mode) or deleted.

    Args:
        table_name: Fully qualified table name (catalog.schema.table).
        reason: Optional reason for locking (stored in tag value).
        locked_by: Optional identifier of who locked the table.
            Defaults to current user from Spark session.

    Raises:
        MLOpsToolkitTableNotFoundException: If the table doesn't exist.
        MLOpsToolkitEvalTableLockedException: If the table is already locked.

    Example:
        >>> lock_eval_table(
        ...     table_name="yd_tagging_platform_evals.vendor_tagging_dataset_bronze.vendor_eval_20260126_143052",
        ...     reason="Production eval set - do not modify",
        ...     locked_by="ml-team"
        ... )

    Notes:
        Lock is implemented via UC tags:
        - Tag key: 'eval_locked'
        - Tag value: JSON with 'locked_at', 'locked_by', 'reason'
    """
    logger = get_logger()
    assert_table_exists(table_name)
    # Default locked_by to the current Spark user; best-effort — fall back to
    # "unknown" rather than failing the lock if the lookup errors out.
    if locked_by is None:
        spark = ydbc.get_spark_session()
        try:
            locked_by = spark.sql("SELECT current_user()").collect()[0][0]
        except Exception:
            locked_by = "unknown"
    set_table_lock(table_name, locked_by=locked_by, reason=reason)
    logger.info(
        f"Locked eval table: {table_name} (by: {locked_by}, reason: {reason})"
    )
def unlock_eval_table(
    table_name: str,
    *,
    force: bool = False,
) -> None:
    """Unlock a previously locked evaluation table.

    Removes the lock tag from the table, allowing modifications.

    Args:
        table_name: Fully qualified table name (catalog.schema.table).
        force: If True, unlock even if locked by a different user.
            Defaults to False.

    Raises:
        MLOpsToolkitTableNotFoundException: If the table doesn't exist.
        ValueError: If the table is not locked.
        PermissionError: If the table was locked by a different user and force=False.

    Example:
        >>> unlock_eval_table(
        ...     table_name="yd_tagging_platform_evals.vendor_tagging_dataset_bronze.vendor_eval_20260126_143052"
        ... )
    """
    logger = get_logger()
    assert_table_exists(table_name)
    # Resolve the current user only when we must enforce ownership; a failed
    # lookup leaves current_user as None and defers the decision to
    # remove_table_lock.
    current_user = None
    if not force:
        spark = ydbc.get_spark_session()
        try:
            current_user = spark.sql("SELECT current_user()").collect()[0][0]
        except Exception:
            pass
    remove_table_lock(table_name, force=force, current_user=current_user)
    logger.info(f"Unlocked eval table: {table_name}")