# Source code for ml_toolkit.functions.llm.udtf.serving_endpoint

"""Serving Endpoint UDTF implementation.

Provides UDTF for querying specific served models in multi-version serving endpoints
using the Databricks REST API.

This follows the standard UDTF template pattern using base.setup_udtf.
"""

import json

from yipit_databricks_client import post
from yipit_databricks_client.helpers.settings import get_api_token, get_deployment_name

from ml_toolkit.functions.llm.udtf.base import setup_udtf
from ml_toolkit.functions.llm.udtf.config import ServingEndpointConfig
from ml_toolkit.ops.helpers.logger import get_logger

logger = get_logger()


def serving_endpoint_process_row(
    config: ServingEndpointConfig, user_prompt: str, row_data: dict
) -> str:
    """Invoke a served model over REST for one input row and return its text output.

    Used as the ``process_fn`` handed to :func:`base.setup_udtf`; retries,
    timestamps, and error wrapping are the base layer's responsibility.

    Args:
        config: ServingEndpointConfig instance
        user_prompt: The input prompt (from the 'prompt' column)
        row_data: Full row data dict

    Returns:
        Model output string
    """
    # Per-row overrides win over config-level defaults.
    ep_name = row_data.get("endpoint_name", config.endpoint_name)
    model_name = row_data.get("served_model_name", config.served_model_name)

    if not (ep_name and model_name):
        raise ValueError(
            "Missing required fields: endpoint_name and/or served_model_name"
        )

    # Path is relative; the api-version prefix is intentionally omitted.
    invocation_path = (
        f"serving-endpoints/{ep_name}/served-models/{model_name}/invocations"
    )

    # ChatML/OpenAI-style messages payload, as expected by custom model
    # endpoints (e.g. QLora fine-tuned models).
    request_body = {
        "messages": [{"role": "user", "content": user_prompt}],
        "max_tokens": row_data.get("max_tokens", 512),
        "temperature": row_data.get("temperature", 0.7),
    }

    # Credentials come from config explicitly: serverless workers cannot
    # read the deployment configuration themselves.
    resp = post(
        invocation_path,
        params=request_body,
        timeout=config.timeout_sec,
        include_api_in_base_url=False,
        deployment_name=config.deployment_name,
        api_token=config.api_token,
    )
    resp.raise_for_status()
    parsed = resp.json()

    # Chat-completion shape: {"choices": [{"message": {"content": "..."}}]}
    if "choices" in parsed and len(parsed["choices"]) > 0:
        first = parsed["choices"][0]
        if "message" in first and "content" in first["message"]:
            return str(first["message"]["content"]).strip()
        if "text" in first:
            # Text-completion fallback shape.
            return str(first["text"]).strip()
        return str(json.dumps(first)).strip()

    # Standard Databricks serving endpoint shape.
    if "predictions" in parsed and len(parsed["predictions"]) > 0:
        return str(parsed["predictions"][0]).strip()

    # Unknown shape: hand back the raw JSON so nothing is silently dropped.
    return str(json.dumps(parsed)).strip()


def setup_serving_endpoint_udtf(
    config: ServingEndpointConfig,
    name: str = "udtf_serving_endpoint",
    prompt_column: str = "prompt",
):
    """Setup and register the serving endpoint UDTF.

    Uses the generic setup_udtf from base.py for consistent behavior.

    Args:
        config: ServingEndpointConfig instance
        name: UDTF name for SQL registration
        prompt_column: Column name containing the user prompt

    Returns:
        The registered UDTF class
    """
    # Capture deployment credentials on the driver (where setup runs).
    # They are stored on config so they serialize with the UDTF to workers.
    if config.deployment_name is None or config.api_token is None:
        try:
            config.deployment_name = get_deployment_name()
            config.api_token = get_api_token(config.deployment_name)
            logger.info(
                f"Captured deployment credentials for UDTF: {config.deployment_name}"
            )
        except Exception as e:
            # Best-effort: leave credentials as None; the failure will then
            # surface on workers with a clearer error.
            logger.warning(f"Failed to capture deployment credentials: {e}")

    return setup_udtf(
        config=config,
        process_fn=serving_endpoint_process_row,
        name=name,
        prompt_column=prompt_column,
    )
# Create default UDTF with default config.
# Note: Set register_in_spark=False to avoid eager registration at import time.
# The UDTF will be registered when query_serving_endpoint() is called.
default_config = ServingEndpointConfig()
UDTFServingEndpoint = None  # Will be lazily initialized