Source code for ml_toolkit.functions.eval_utils.table_management

"""Public API for eval table_name management.

This module provides functions to create, update, and manage evaluation tables
in Unity Catalog.
"""

from pyspark.sql import DataFrame
import yipit_databricks_client as ydbc
from yipit_databricks_utils.helpers.delta import create_table

from ml_toolkit.functions.eval_utils.constants import EVAL_CATALOG
from ml_toolkit.functions.eval_utils.helpers.table_management import (
    add_tags_to_table,
    generate_timestamped_table_name,
    get_base_table_name,
    get_table_lock_info,
    remove_table_lock,
    set_table_lock,
    validate_eval_dataframe,
    validate_schema_name,
)
from ml_toolkit.ops.helpers.exceptions import (
    AggregateExceptionHandler,
    MLOpsToolkitEvalTableLockedException,
    MLOpsToolkitSchemaNotFoundException,
)
from ml_toolkit.ops.helpers.logger import get_logger, suppress_table_creation_logs
from ml_toolkit.ops.helpers.validation import assert_table_exists, schema_exists


def create_eval_table(
    schema_name: str,
    table_name: str,
    primary_key: str,
    df: DataFrame,
    *,
    catalog_name: str = EVAL_CATALOG,
    additional_columns: dict[str, str] | None = None,
    description: str | None = None,
    tags: dict[str, str] | None = None,
) -> str:
    """Create a new evaluation dataset table in Unity Catalog.

    Creates a Delta table with the standard eval schema plus any additional
    columns. The table name will be suffixed with a datetime stamp. The
    schema name must end with '_dataset_bronze'.

    Args:
        schema_name: Schema name (must end with '_dataset_bronze').
        table_name: Base table name (datetime suffix will be appended).
        primary_key: Name of the primary key column in the DataFrame.
        df: Spark DataFrame containing the evaluation data. Must include
            'input' column and the specified primary key column.
        catalog_name: Unity Catalog name. Defaults to 'yd_tagging_platform_evals'.
        additional_columns: Optional dict mapping column names to SQL types
            for additional columns beyond the standard schema.
        description: Optional table description for UC metadata.
        tags: Optional dict of UC tags to apply to the table.

    Returns:
        Fully qualified table name:
        '{catalog_name}.{schema_name}.{table_name}_{timestamp}'

    Raises:
        MLOpsToolkitInvalidSchemaNameException: If schema_name doesn't end
            with '_dataset_bronze'.
        MLOpsToolkitColumnNotFoundException: If DataFrame is missing 'input'
            or primary key column.
        MLOpsToolkitDuplicatePrimaryKeyException: If primary key column
            contains duplicates.
        MLOpsToolkitSchemaNotFoundException: If schema doesn't exist.

    Example:
        >>> df = spark.createDataFrame([
        ...     {"row_id": "1", "input": "acme corp", "candidates": ["Acme", "ACME Corp"]},
        ...     {"row_id": "2", "input": "widgets inc", "candidates": ["Widgets", "Widgets Inc"]},
        ... ])
        >>> table_name = create_eval_table(
        ...     schema_name="vendor_tagging_dataset_bronze",
        ...     table_name="vendor_eval",
        ...     primary_key="row_id",
        ...     df=df,
        ...     description="Vendor name tagging evaluation set v1",
        ...     tags={"domain": "vendor", "version": "1.0"}
        ... )
        >>> print(table_name)
        'yd_tagging_platform_evals.vendor_tagging_dataset_bronze.vendor_eval_20260126_143052'
    """
    # NOTE(review): additional_columns is accepted but never used in this
    # function body — confirm whether it should be forwarded to create_table.
    logger = get_logger()

    # Validate all inputs inside the aggregate handler so every problem is
    # reported at once. (Previously validate_eval_dataframe was also called
    # directly before this block, running the validation twice and raising
    # the first error before the aggregate report could be built.)
    with AggregateExceptionHandler() as exc_handler:
        exc_handler.collect(validate_schema_name, schema_name)
        exc_handler.collect(validate_eval_dataframe, df, primary_key)
        schema_full_name = f"{catalog_name}.{schema_name}"
        if not schema_exists(schema_full_name):
            exc_handler.collect_raise(
                MLOpsToolkitSchemaNotFoundException, schema_full_name
            )
        exc_handler.raise_if_any()

    # Generate a timestamped name so each creation is a distinct version.
    timestamped_name = generate_timestamped_table_name(table_name)
    full_table_name = f"{catalog_name}.{schema_name}.{timestamped_name}"

    logger.info(f"Creating eval table: {full_table_name}")

    # Create the table; overwrite=False guards against clobbering an
    # existing version with the same timestamp.
    with suppress_table_creation_logs():
        create_table(
            schema_name=schema_name,
            table_name=timestamped_name,
            query=df,
            catalog_name=catalog_name,
            overwrite=False,
            table_comment=description,
            spark_options={"mergeSchema": "true"},
        )

    add_tags_to_table(full_table_name, tags)

    logger.info(f"Created eval table: {full_table_name} with {df.count()} rows")
    return full_table_name
def upsert_eval_table(
    table_name: str,
    df: "DataFrame",
    primary_key: str,
    *,
    create_new_version: bool = True,
) -> str:
    """Upsert rows into an existing evaluation table or create a new version.

    If create_new_version is True (default), creates a new timestamped table
    with the combined data. If False, performs an in-place MERGE operation on
    the existing table (requires the table to be unlocked).

    Args:
        table_name: Fully qualified table name (catalog.schema.table).
        df: Spark DataFrame containing rows to upsert. Must include the
            primary key column and 'input' column.
        primary_key: Name of the primary key column for merge matching.
        create_new_version: If True, create a new timestamped table with
            merged data. If False, update the existing table in place.

    Returns:
        Fully qualified table name (new name if versioned, same if in-place).

    Raises:
        MLOpsToolkitTableNotFoundException: If the table doesn't exist.
        MLOpsToolkitEvalTableLockedException: If the table is locked and
            create_new_version is False.
        MLOpsToolkitColumnNotFoundException: If DataFrame schema doesn't
            match the existing table.

    Example:
        >>> new_rows = spark.createDataFrame([
        ...     {"row_id": "3", "input": "new vendor", "candidates": ["New Vendor"]}
        ... ])
        >>> new_table = upsert_eval_table(
        ...     table_name="yd_tagging_platform_evals.vendor_tagging_dataset_bronze.vendor_eval_20260126_143052",
        ...     df=new_rows,
        ...     primary_key="row_id",
        ...     create_new_version=True
        ... )
    """
    logger = get_logger()
    spark = ydbc.get_spark_session()

    assert_table_exists(table_name)
    validate_eval_dataframe(df, primary_key)

    if create_new_version:
        existing_df = spark.table(table_name)
        # Keep existing rows whose key does NOT appear in the new data
        # (left_anti), then append all new rows — so on a key collision the
        # incoming row wins.
        merged_df = (
            existing_df.alias("existing")
            .join(df.alias("new"), on=primary_key, how="left_anti")
            .unionByName(df, allowMissingColumns=True)
        )

        # Split the FQN into its components without shadowing the
        # table_name parameter (the original rebound it mid-function).
        parts = table_name.split(".")
        catalog_name, schema_name, current_table = parts[0], parts[1], parts[2]
        base_name = get_base_table_name(current_table)

        # Create the new versioned table.
        timestamped_name = generate_timestamped_table_name(base_name)
        new_full_name = f"{catalog_name}.{schema_name}.{timestamped_name}"

        with suppress_table_creation_logs():
            create_table(
                schema_name=schema_name,
                table_name=timestamped_name,
                query=merged_df,
                catalog_name=catalog_name,
                overwrite=False,
                spark_options={"mergeSchema": "true"},
            )

        logger.info(
            f"Created new version: {new_full_name} with {merged_df.count()} rows"
        )
        return new_full_name
    else:
        # In-place update — refuse if the table is locked.
        lock_info = get_table_lock_info(table_name)
        if lock_info is not None:
            raise MLOpsToolkitEvalTableLockedException(
                table_name,
                locked_by=lock_info.get("locked_by"),
                reason=lock_info.get("reason"),
            )

        # Perform the MERGE via a temp view over the incoming rows.
        df.createOrReplaceTempView("upsert_data")
        merge_sql = f"""
            MERGE INTO {table_name} AS target
            USING upsert_data AS source
            ON target.{primary_key} = source.{primary_key}
            WHEN MATCHED THEN UPDATE SET *
            WHEN NOT MATCHED THEN INSERT *
        """
        spark.sql(merge_sql)
        logger.info(f"Updated table in place: {table_name}")
        return table_name
def lock_eval_table(
    table_name: str,
    *,
    reason: str | None = None,
    locked_by: str | None = None,
) -> None:
    """Lock an evaluation table to prevent modifications.

    Applies a UC tag to mark the table as locked. Locked tables cannot be
    modified via upsert_eval_table (in-place mode) or deleted.

    Args:
        table_name: Fully qualified table name (catalog.schema.table).
        reason: Optional reason for locking (stored in tag value).
        locked_by: Optional identifier of who locked the table. Defaults to
            the current user from the Spark session.

    Raises:
        MLOpsToolkitTableNotFoundException: If the table doesn't exist.
        MLOpsToolkitEvalTableLockedException: If the table is already locked.

    Example:
        >>> lock_eval_table(
        ...     table_name="yd_tagging_platform_evals.vendor_tagging_dataset_bronze.vendor_eval_20260126_143052",
        ...     reason="Production eval set - do not modify",
        ...     locked_by="ml-team"
        ... )

    Notes:
        Lock is implemented via UC tags:
        - Tag key: 'eval_locked'
        - Tag value: JSON with 'locked_at', 'locked_by', 'reason'
    """
    logger = get_logger()
    assert_table_exists(table_name)

    # Resolve the current user when the caller did not supply one.
    if locked_by is None:
        spark = ydbc.get_spark_session()
        try:
            locked_by = spark.sql("SELECT current_user()").collect()[0][0]
        except Exception:
            # Best-effort: user lookup can fail outside an interactive
            # Databricks session; fall back to a sentinel rather than abort.
            locked_by = "unknown"

    set_table_lock(table_name, locked_by=locked_by, reason=reason)
    logger.info(
        f"Locked eval table: {table_name} (by: {locked_by}, reason: {reason})"
    )
def unlock_eval_table(
    table_name: str,
    *,
    force: bool = False,
) -> None:
    """Unlock a previously locked evaluation table.

    Removes the lock tag from the table, allowing modifications.

    Args:
        table_name: Fully qualified table name (catalog.schema.table).
        force: If True, unlock even if locked by a different user.
            Defaults to False.

    Raises:
        MLOpsToolkitTableNotFoundException: If the table doesn't exist.
        ValueError: If the table is not locked.
        PermissionError: If the table was locked by a different user and
            force=False.

    Example:
        >>> unlock_eval_table(
        ...     table_name="yd_tagging_platform_evals.vendor_tagging_dataset_bronze.vendor_eval_20260126_143052"
        ... )
    """
    logger = get_logger()
    assert_table_exists(table_name)

    # Only look up the current user when it matters for the permission
    # check; with force=True the helper unlocks regardless of owner.
    current_user = None
    if not force:
        spark = ydbc.get_spark_session()
        try:
            current_user = spark.sql("SELECT current_user()").collect()[0][0]
        except Exception:
            # Best-effort: leave current_user as None and let
            # remove_table_lock decide how to handle an unknown caller.
            pass

    remove_table_lock(table_name, force=force, current_user=current_user)
    logger.info(f"Unlocked eval table: {table_name}")