"""Helper functions for querying serving endpoints."""
import uuid
from pyspark.sql import DataFrame
from pyspark.sql import functions as F
from yipit_databricks_client import get_spark_session
from ml_toolkit.ops.helpers.logger import get_logger
logger = get_logger()
# Module-level flag for lazy UDTF registration
_default_udtf_registered = False

def _extract_versions(served_entities: list) -> list[int]:
    """Parse integer versions from served entity names matching the 'v{version}' pattern.

    Entity names that do not look like 'v1', 'v2', etc. are skipped.
    """
    versions = []
    for entity in served_entities:
        entity_name = entity.get("name", "")
        if entity_name.startswith("v") and entity_name[1:].isdigit():
            try:
                versions.append(int(entity_name[1:]))
            except ValueError:
                # isdigit() accepts some non-ASCII digits that int() rejects
                continue
    return versions

def _ensure_udtf_registered(spark):
"""Ensure the default serving endpoint UDTF is registered.
This is called lazily on first use to avoid eager registration at import time.
"""
global _default_udtf_registered
if not _default_udtf_registered:
from ml_toolkit.functions.llm.udtf.serving_endpoint import (
ServingEndpointConfig,
setup_serving_endpoint_udtf,
)
logger.debug("Registering default UDTF 'udtf_serving_endpoint'")
config = ServingEndpointConfig()
setup_serving_endpoint_udtf(config, name="udtf_serving_endpoint")
_default_udtf_registered = True
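# Illustrative call pattern (not executed here): only the first call pays the
# registration cost; later calls are no-ops thanks to the module-level flag.
#
#     _ensure_udtf_registered(spark)  # registers 'udtf_serving_endpoint'
#     _ensure_udtf_registered(spark)  # no-op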
def get_latest_version_from_endpoint(endpoint_name: str) -> int:
"""Get the latest (highest) version number from a serving endpoint.
Args:
endpoint_name: Name of the serving endpoint
Returns:
Latest version number as an integer
Raises:
ValueError: If no valid versions found in the endpoint
"""
from ml_toolkit.functions.llm.serving_endpoints.function import (
get_serving_endpoint as _get_serving_endpoint,
)
endpoint_info = _get_serving_endpoint(endpoint_name)
config = endpoint_info.get("config", {})
served_entities = config.get("served_entities", [])
    # Extract all version numbers from served entity names (pattern: v{version})
    available_versions = _extract_versions(served_entities)
if not available_versions:
raise ValueError(
f"No valid versions found in endpoint '{endpoint_name}'. "
f"Expected served entities with names like 'v1', 'v2', etc."
)
latest_version = max(available_versions)
logger.info(
f"Resolved latest version for endpoint '{endpoint_name}': v{latest_version} "
f"(available versions: {sorted(available_versions)})"
)
return latest_version
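# Illustrative behavior (hypothetical names): for served entities named
# ['v1', 'v3', 'v10'], the parsed versions are [1, 3, 10] and the function
# returns 10; non-matching names such as 'champion' are ignored.
#
#     latest = get_latest_version_from_endpoint("my_endpoint")  # -> 10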
def query_serving_endpoint(
df: DataFrame,
endpoint_name: str,
version: int | None = None,
prompt_column: str = "prompt",
    prompt_template: str | None = None,
max_tokens: int = 512,
temperature: float = 0.7,
output_column: str = "output",
) -> DataFrame:
"""Query a specific model version in a multi-version serving endpoint.
This convenience function prepares the DataFrame and calls the UDTF, automatically
routing to the specified version (v1, v2, v3, etc.).
Args:
df: Input DataFrame with a prompt column
endpoint_name: Name of the serving endpoint
        version: Integer version identifier (e.g., 1, 2, 3). Routes to the served
            model named "v{version}" (e.g., v1, v2, v3). Defaults to None, which
            resolves to the highest available version.
prompt_column: Name of the column containing prompts (default: "prompt").
Ignored if prompt_template is specified.
prompt_template: Optional template string with <<column>> placeholders.
Use double angle brackets to reference DataFrame columns.
Examples: "Classify: <<text>>", "Extract entity: <<merchant_name>>"
max_tokens: Maximum output tokens (default: 512)
temperature: Sampling temperature (default: 0.7)
output_column: Name of the output column (default: "output")
Returns:
DataFrame with columns: {output_column}, start_timestamp, end_timestamp, error_message
Example:
>>> from ml_toolkit.functions.llm.serving_endpoints.helpers.query import query_serving_endpoint
>>> from pyspark.sql import SparkSession
>>>
>>> spark = SparkSession.builder.getOrCreate()
>>> df = spark.table("input_table")
>>>
>>> # Query latest version (default)
>>> result_df = query_serving_endpoint(
... df=df,
... endpoint_name="my_endpoint",
... prompt_column="text",
... max_tokens=1000,
... temperature=0.5
... )
>>>
>>> # Query specific version
>>> result_df = query_serving_endpoint(
... df=df,
... endpoint_name="my_endpoint",
... version=1,
... prompt_template="Classify this <<category>> text: <<content>>",
... max_tokens=1000,
... temperature=0.5
... )
"""
spark = get_spark_session()
# Ensure UDTF is registered (lazy initialization on first call)
_ensure_udtf_registered(spark)
# Get endpoint configuration to find the served model name for this version
from ml_toolkit.functions.llm.serving_endpoints.function import (
get_serving_endpoint as _get_serving_endpoint,
)
from ml_toolkit.ops.helpers.exceptions import (
MLOpsToolkitEndpointVersionNotFoundException,
)
endpoint_info = _get_serving_endpoint(endpoint_name)
config = endpoint_info.get("config", {})
served_entities = config.get("served_entities", [])
# Resolve version if not specified
if version is None:
version = get_latest_version_from_endpoint(endpoint_name)
# Find the served model name that matches our version pattern
# Pattern: v{version} (e.g., v1, v2, v3)
served_model_name = f"v{version}"
    version_is_served = any(
        entity.get("name") == served_model_name for entity in served_entities
    )
    if not version_is_served:
        # Surface the versions the endpoint actually exposes in the error
        available_versions = _extract_versions(served_entities)
raise MLOpsToolkitEndpointVersionNotFoundException(
endpoint_name=endpoint_name,
version=version,
available_versions=available_versions,
)
logger.info(
f"Resolved version '{version}' to served model name '{served_model_name}'"
)
# Prepare DataFrame with required columns
df_prepared = df.withColumn("endpoint_name", F.lit(endpoint_name))
df_prepared = df_prepared.withColumn("served_model_name", F.lit(served_model_name))
df_prepared = df_prepared.withColumn("max_tokens", F.lit(max_tokens))
df_prepared = df_prepared.withColumn("temperature", F.lit(temperature))
# Handle prompt template if provided
if prompt_template is not None:
from ml_toolkit.functions.llm.helpers.prompt import (
find_cols_in_prompt_source,
transform_prompt_source_for_ai_query,
)
        # Transform the template into a SQL format_string expression
        cols = find_cols_in_prompt_source(prompt_template)
        transformed_prompt = transform_prompt_source_for_ai_query(prompt_template)
        if cols:
            col_refs = ", ".join(cols)
            prompt_expr = f"format_string('{transformed_prompt}', {col_refs})"
        else:
            # A template with no <<column>> placeholders yields a constant prompt;
            # skip format_string, which would be a syntax error with no arguments
            prompt_expr = f"'{transformed_prompt}'"
        df_prepared = df_prepared.withColumn("prompt", F.expr(prompt_expr))
elif prompt_column != "prompt":
# If no template and prompt column is different, rename it
df_prepared = df_prepared.withColumnRenamed(prompt_column, "prompt")
# Create temp view and call UDTF
temp_view = f"_serving_endpoint_input_{uuid.uuid4().hex[:8]}"
df_prepared.createOrReplaceTempView(temp_view)
logger.info(
f"Querying serving endpoint '{endpoint_name}' version '{version}' (served model: '{served_model_name}')"
)
result_sql = f"""
SELECT *
FROM udtf_serving_endpoint(
TABLE(
SELECT endpoint_name, served_model_name, prompt, max_tokens, temperature
FROM {temp_view}
)
)
"""
result_df = spark.sql(result_sql)
if output_column != "output":
result_df = result_df.withColumnRenamed("output", output_column)
return result_df
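
# Illustrative end-to-end usage (hypothetical table and endpoint names):
#
#     df = spark.table("catalog.schema.support_tickets")
#     result = query_serving_endpoint(
#         df,
#         endpoint_name="my_endpoint",
#         prompt_template="Summarize this ticket: <<body>>",
#         output_column="summary",
#     )
#     # Rows that failed at the endpoint carry a non-null error_message
#     failures = result.filter("error_message IS NOT NULL")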