# Source code for ml_toolkit.functions.llm.serving_endpoints.helpers.query

"""Helper functions for querying serving endpoints."""

import uuid

from pyspark.sql import DataFrame
from pyspark.sql import functions as F
from yipit_databricks_client import get_spark_session

from ml_toolkit.ops.helpers.logger import get_logger

logger = get_logger()

# Module-level flag for lazy UDTF registration
_default_udtf_registered = False


def _ensure_udtf_registered(spark):
    """Register the default serving endpoint UDTF exactly once.

    Registration is deferred until the first call so that importing this
    module never triggers an eager UDTF registration.

    Args:
        spark: Spark session. Currently unused by the registration call;
            kept for interface compatibility with callers.
    """
    global _default_udtf_registered

    # Guard clause: nothing to do once the UDTF has been registered.
    if _default_udtf_registered:
        return

    # Local import keeps the UDTF machinery out of module import time.
    from ml_toolkit.functions.llm.udtf.serving_endpoint import (
        ServingEndpointConfig,
        setup_serving_endpoint_udtf,
    )

    logger.debug("Registering default UDTF 'udtf_serving_endpoint'")
    setup_serving_endpoint_udtf(ServingEndpointConfig(), name="udtf_serving_endpoint")
    _default_udtf_registered = True


def get_latest_version_from_endpoint(endpoint_name: str) -> int:
    """Get the latest (highest) version number from a serving endpoint.

    Served entities are expected to be named ``v{N}`` (e.g. ``v1``, ``v2``);
    entities whose names do not match that pattern are ignored.

    Args:
        endpoint_name: Name of the serving endpoint

    Returns:
        Latest version number as an integer

    Raises:
        ValueError: If no valid versions found in the endpoint
    """
    # Local import mirrors the original code; presumably avoids a circular
    # import at module load time — TODO confirm.
    from ml_toolkit.functions.llm.serving_endpoints.function import (
        get_serving_endpoint as _get_serving_endpoint,
    )

    endpoint_info = _get_serving_endpoint(endpoint_name)
    served_entities = endpoint_info.get("config", {}).get("served_entities", [])

    # Collect versions from entity names matching the 'v{version}' pattern.
    # str.isdecimal() (unlike isdigit(), which also accepts e.g. Unicode
    # superscripts) guarantees that int() cannot raise, so the former
    # guard-plus-try/except is unnecessary; non-matching names are skipped
    # either way.
    available_versions = [
        int(entity["name"][1:])
        for entity in served_entities
        if entity["name"].startswith("v") and entity["name"][1:].isdecimal()
    ]

    if not available_versions:
        raise ValueError(
            f"No valid versions found in endpoint '{endpoint_name}'. "
            f"Expected served entities with names like 'v1', 'v2', etc."
        )

    latest_version = max(available_versions)
    logger.info(
        f"Resolved latest version for endpoint '{endpoint_name}': v{latest_version} "
        f"(available versions: {sorted(available_versions)})"
    )
    return latest_version


def query_serving_endpoint(
    df: DataFrame,
    endpoint_name: str,
    version: int | None = None,
    prompt_column: str = "prompt",
    prompt_template: str | None = None,
    max_tokens: int = 512,
    temperature: float = 0.7,
    output_column: str = "output",
) -> DataFrame:
    """Query a specific model version in a multi-version serving endpoint.

    This convenience function prepares the DataFrame and calls the UDTF,
    automatically routing to the specified version (v1, v2, v3, etc.).

    Args:
        df: Input DataFrame with a prompt column
        endpoint_name: Name of the serving endpoint
        version: Integer version identifier (e.g., 1, 2, 3). Defaults to None
            (latest version). Routes to served model "v{version}" (e.g., v1,
            v2, v3). If None, automatically resolves to the highest available
            version.
        prompt_column: Name of the column containing prompts (default:
            "prompt"). Ignored if prompt_template is specified.
        prompt_template: Optional template string with <<column>> placeholders.
            Use double angle brackets to reference DataFrame columns.
            Examples: "Classify: <<text>>", "Extract entity: <<merchant_name>>"
        max_tokens: Maximum output tokens (default: 512)
        temperature: Sampling temperature (default: 0.7)
        output_column: Name of the output column (default: "output")

    Returns:
        DataFrame with columns: {output_column}, start_timestamp,
        end_timestamp, error_message

    Raises:
        MLOpsToolkitEndpointVersionNotFoundException: If the requested version
            has no matching served entity on the endpoint.

    Example:
        >>> from ml_toolkit.functions.llm.serving_endpoints.helpers.query import query_serving_endpoint
        >>> from pyspark.sql import SparkSession
        >>>
        >>> spark = SparkSession.builder.getOrCreate()
        >>> df = spark.table("input_table")
        >>>
        >>> # Query latest version (default)
        >>> result_df = query_serving_endpoint(
        ...     df=df,
        ...     endpoint_name="my_endpoint",
        ...     prompt_column="text",
        ...     max_tokens=1000,
        ...     temperature=0.5,
        ... )
        >>>
        >>> # Query specific version
        >>> result_df = query_serving_endpoint(
        ...     df=df,
        ...     endpoint_name="my_endpoint",
        ...     version=1,
        ...     prompt_template="Classify this <<category>> text: <<content>>",
        ...     max_tokens=1000,
        ...     temperature=0.5,
        ... )
    """
    spark = get_spark_session()

    # Ensure UDTF is registered (lazy initialization on first call)
    _ensure_udtf_registered(spark)

    # Local imports mirror the original code; presumably avoid circular
    # imports at module load time — TODO confirm.
    from ml_toolkit.functions.llm.serving_endpoints.function import (
        get_serving_endpoint as _get_serving_endpoint,
    )
    from ml_toolkit.ops.helpers.exceptions import (
        MLOpsToolkitEndpointVersionNotFoundException,
    )

    endpoint_info = _get_serving_endpoint(endpoint_name)
    served_entities = endpoint_info.get("config", {}).get("served_entities", [])

    # Resolve version if not specified
    if version is None:
        version = get_latest_version_from_endpoint(endpoint_name)

    # Served models are named v{version} (e.g. v1, v2, v3).
    served_model_name = f"v{version}"
    if not any(entity["name"] == served_model_name for entity in served_entities):
        # Report which versions ARE available. str.isdecimal() guarantees
        # int() cannot raise (isdigit() would also accept e.g. superscripts),
        # so no try/except is needed here.
        available_versions = [
            int(entity["name"][1:])
            for entity in served_entities
            if entity["name"].startswith("v") and entity["name"][1:].isdecimal()
        ]
        raise MLOpsToolkitEndpointVersionNotFoundException(
            endpoint_name=endpoint_name,
            version=version,
            available_versions=available_versions,
        )

    logger.info(
        f"Resolved version '{version}' to served model name '{served_model_name}'"
    )

    # Prepare DataFrame with the constant columns the UDTF expects.
    df_prepared = (
        df.withColumn("endpoint_name", F.lit(endpoint_name))
        .withColumn("served_model_name", F.lit(served_model_name))
        .withColumn("max_tokens", F.lit(max_tokens))
        .withColumn("temperature", F.lit(temperature))
    )

    # Handle prompt template if provided
    if prompt_template is not None:
        from ml_toolkit.functions.llm.helpers.prompt import (
            find_cols_in_prompt_source,
            transform_prompt_source_for_ai_query,
        )

        # Transform the <<col>> template into a SQL format_string pattern.
        cols = find_cols_in_prompt_source(prompt_template)
        transformed_prompt = transform_prompt_source_for_ai_query(prompt_template)
        col_refs = ", ".join(cols)

        # NOTE(review): the transformed template is interpolated directly
        # into a SQL expression; a template containing a single quote would
        # break the expression (or allow injection). Consider escaping it.
        df_prepared = df_prepared.withColumn(
            "prompt",
            F.expr(f"format_string('{transformed_prompt}', {col_refs})"),
        )
    elif prompt_column != "prompt":
        # If no template and prompt column is different, rename it
        df_prepared = df_prepared.withColumnRenamed(prompt_column, "prompt")

    # Create a uniquely-named temp view so concurrent calls don't collide.
    temp_view = f"_serving_endpoint_input_{uuid.uuid4().hex[:8]}"
    df_prepared.createOrReplaceTempView(temp_view)

    logger.info(
        f"Querying serving endpoint '{endpoint_name}' version '{version}' (served model: '{served_model_name}')"
    )

    result_sql = f"""
    SELECT * FROM udtf_serving_endpoint(
        TABLE(
            SELECT endpoint_name, served_model_name, prompt, max_tokens, temperature
            FROM {temp_view}
        )
    )
    """
    result_df = spark.sql(result_sql)

    if output_column != "output":
        result_df = result_df.withColumnRenamed("output", output_column)

    return result_df