# Source code for ml_toolkit.functions.llm.udtf.serving_endpoint
"""Serving Endpoint UDTF implementation.
Provides UDTF for querying specific served models in multi-version serving endpoints
using the Databricks REST API.
This follows the standard UDTF template pattern using base.setup_udtf.
"""
import json
from yipit_databricks_client import post
from yipit_databricks_client.helpers.settings import get_api_token, get_deployment_name
from ml_toolkit.functions.llm.udtf.base import setup_udtf
from ml_toolkit.functions.llm.udtf.config import ServingEndpointConfig
from ml_toolkit.ops.helpers.logger import get_logger
logger = get_logger()
def serving_endpoint_process_row(
    config: ServingEndpointConfig, user_prompt: str, row_data: dict
) -> str:
    """Process a single row by calling the serving endpoint via REST API.

    This is the ``process_fn`` passed to :func:`base.setup_udtf`. The base
    layer handles retry, timestamps, and error wrapping.

    Args:
        config: ServingEndpointConfig instance supplying endpoint defaults,
            timeout, and (optionally) pre-captured deployment credentials.
        user_prompt: The input prompt (from the 'prompt' column).
        row_data: Full row data dict; may override ``endpoint_name``,
            ``served_model_name``, ``max_tokens``, and ``temperature``.

    Returns:
        Model output string, stripped of surrounding whitespace.

    Raises:
        ValueError: If endpoint_name or served_model_name resolves to empty.
    """
    # Extract parameters from row, falling back to config / hard-coded defaults
    endpoint_name = row_data.get("endpoint_name", config.endpoint_name)
    served_model_name = row_data.get("served_model_name", config.served_model_name)
    max_tokens = row_data.get("max_tokens", 512)
    temperature = row_data.get("temperature", 0.7)

    # Validate required fields before touching the network
    if not endpoint_name or not served_model_name:
        raise ValueError(
            "Missing required fields: endpoint_name and/or served_model_name"
        )

    # Build REST API URL (without api version prefix)
    url = f"serving-endpoints/{endpoint_name}/served-models/{served_model_name}/invocations"

    # Build request payload using messages format (ChatML/OpenAI-style).
    # This is the format expected by custom model endpoints like QLora
    # fine-tuned models.
    payload = {
        "messages": [{"role": "user", "content": user_prompt}],
        "max_tokens": max_tokens,
        "temperature": temperature,
    }

    # Make request using ydbc.post with explicit credentials from config.
    # This is necessary in serverless Databricks where workers cannot access
    # deployment config.
    response = post(
        url,
        params=payload,
        timeout=config.timeout_sec,
        include_api_in_base_url=False,
        deployment_name=config.deployment_name,
        api_token=config.api_token,
    )
    response.raise_for_status()
    return _extract_output(response.json())


def _extract_output(result: dict) -> str:
    """Pull the model text out of a serving-endpoint response body.

    Supported response shapes, in priority order:
    1. ChatML/OpenAI chat completion: {"choices": [{"message": {"content": ...}}]}
    2. Text completion fallback:      {"choices": [{"text": ...}]}
    3. Standard Databricks serving:   {"predictions": [...]}
    Anything unrecognized is serialized back to JSON so the caller always
    receives a string.
    """
    if "choices" in result and len(result["choices"]) > 0:
        # ChatML/OpenAI-style response
        choice = result["choices"][0]
        if "message" in choice and "content" in choice["message"]:
            output = choice["message"]["content"]
        elif "text" in choice:
            # Fallback for text completion format
            output = choice["text"]
        else:
            # Unknown choice shape — return it verbatim for debuggability
            output = json.dumps(choice)
    elif "predictions" in result and len(result["predictions"]) > 0:
        # Standard Databricks serving endpoint format
        output = result["predictions"][0]
    else:
        output = json.dumps(result)
    return str(output).strip()
# (removed stray "[docs]" Sphinx HTML-export artifact — not valid Python)
def setup_serving_endpoint_udtf(
    config: ServingEndpointConfig,
    name: str = "udtf_serving_endpoint",
    prompt_column: str = "prompt",
):
    """Setup and register the serving endpoint UDTF.

    Uses the generic setup_udtf from base.py for consistent behavior.

    Args:
        config: ServingEndpointConfig instance; mutated in place if deployment
            credentials need to be captured.
        name: UDTF name for SQL registration
        prompt_column: Column name containing the user prompt

    Returns:
        The registered UDTF class
    """
    # Capture deployment credentials on the driver (where setup runs).
    # Store them in config so they're serialized with the UDTF to workers.
    if config.deployment_name is None or config.api_token is None:
        try:
            config.deployment_name = get_deployment_name()
            config.api_token = get_api_token(config.deployment_name)
            # Lazy %-style args: the message is only formatted if INFO is enabled.
            logger.info(
                "Captured deployment credentials for UDTF: %s",
                config.deployment_name,
            )
        except Exception as e:
            # Deliberate best-effort: leave credentials as None. The UDTF will
            # fail on workers, but with a clearer error than crashing setup here.
            logger.warning("Failed to capture deployment credentials: %s", e)
    return setup_udtf(
        config=config,
        process_fn=serving_endpoint_process_row,
        name=name,
        prompt_column=prompt_column,
    )
# Create default UDTF with default config.
# Note: Set register_in_spark=False to avoid eager registration at import time.
# The UDTF will be registered when query_serving_endpoint() is called
# (query_serving_endpoint is presumably defined elsewhere in this module —
# not visible in this chunk).
default_config = ServingEndpointConfig()
UDTFServingEndpoint = None  # Will be lazily initialized