Source code for ml_toolkit.functions.llm.serving_endpoints.interface

"""High-level interface for deploying and managing serving endpoints."""

from typing import Dict, List, Optional

from yipit_databricks_client.helpers.telemetry import track_usage

from ml_toolkit.functions.llm.serving_endpoints.constants import (
    DEFAULT_TIMEOUT,
    DEFAULT_WAIT_FOR_READY,
    EndpointScalingSize,
    ProvisioningType,
)
from ml_toolkit.functions.llm.serving_endpoints.function import (
    deploy_serving_endpoint as _deploy_serving_endpoint,
)
from ml_toolkit.functions.llm.serving_endpoints.function import (
    get_serving_endpoint as _get_serving_endpoint,
)
from ml_toolkit.functions.llm.serving_endpoints.function import (
    list_serving_endpoints as _list_serving_endpoints,
)
from ml_toolkit.functions.llm.serving_endpoints.helpers.deployment import (
    build_served_entity,
    check_endpoint_exists,
    extract_cost_component_from_endpoint,
)
from ml_toolkit.functions.llm.serving_endpoints.helpers.optimization import (
    resolve_provisioning_config,
)
from ml_toolkit.functions.llm.serving_endpoints.helpers.traffic import (
    calculate_automatic_traffic,
    calculate_traffic_with_specific_percentage,
    calculate_traffic_with_zero_for_new_version,
)
from ml_toolkit.functions.llm.serving_endpoints.helpers.validation import (
    build_served_model_name,
    infer_endpoint_name_from_model,
    validate_endpoint_name,
)
from ml_toolkit.ops.helpers.exceptions import (
    MLOpsToolkitEndpointAlreadyExistsException,
    MLOpsToolkitEndpointNotFoundException,
    MLOpsToolkitUCVersionNotFoundForEndpointException,
)
from ml_toolkit.ops.helpers.logger import get_logger
from ml_toolkit.ops.helpers.models import get_mlflow_model_version_obj
from ml_toolkit.ops.helpers.validation import for_loop_guardrail

logger = get_logger()


@track_usage
@for_loop_guardrail(min_interval_seconds=10)
def deploy_model_serving_endpoint(
    model_name: str,
    model_versions: List[int],
    endpoint_name: Optional[str] = None,
    endpoint_scaling_size: EndpointScalingSize = "SMALL",
    provisioning_type: Optional[ProvisioningType] = None,
    traffic_config: Optional[Dict[int, int]] = None,
    cost_component_name: Optional[str] = None,
    tags: Optional[Dict[str, str]] = None,
    scale_to_zero_enabled: bool = True,
    wait_for_ready: bool = DEFAULT_WAIT_FOR_READY,
    timeout: int = DEFAULT_TIMEOUT,
) -> Dict:
    """
    Deploy one or more versions of a Unity Catalog model to a NEW serving endpoint.

    This function creates a new serving endpoint. If the endpoint already exists
    with any served models, it will raise MLOpsToolkitEndpointAlreadyExistsException.
    Use add_model_version_to_endpoint() to add versions to an existing endpoint.

    Automatically determines whether to use **provisioned throughput** or **GPU**
    config by calling the Databricks ``get-model-optimization-info`` API. For
    provisioned throughput models, token values are rounded to the model's
    chunk_size increment. For GPU models, the ``endpoint_scaling_size`` maps to a
    workload_type + workload_size combo.

    Parameters
    ^^^^^^^^^^
    :param model_name: Unity Catalog model path (catalog.schema.model). Required.
        Example: ``"yd_ml_dpe.fine_tune_staging.qlora_tinyllama_model"``
    :param model_versions: List of Unity Catalog model version numbers (integers)
        to deploy. Each entry corresponds to a version on the endpoint. Required.
        Example: ``[1]`` for single version, ``[1, 2]`` for two versions.
        The served model names will be: ``v1``, ``v2``, etc.
    :param endpoint_name: Optional name of the serving endpoint to create.
        If not provided, the endpoint name will be automatically inferred from the
        model name (the last part after the final dot). Must not already exist with
        served models, or MLOpsToolkitEndpointAlreadyExistsException will be raised.
        Example: If model_name is ``"catalog.schema.my_model"``, endpoint_name
        defaults to ``"my_model"``
    :param endpoint_scaling_size: Endpoint scaling size. Defaults to ``"SMALL"``.
        Must be one of: ``"XSMALL"``, ``"SMALL"``, ``"MEDIUM"``, ``"LARGE"``, ``"XLARGE"``.
        For provisioned throughput models, this maps to token-per-second ranges.
        For GPU models, this maps to workload_type + workload_size combos:
        XSMALL=GPU_SMALL/Small, SMALL=GPU_MEDIUM/Small, MEDIUM=GPU_MEDIUM/Medium,
        LARGE=MULTIGPU_MEDIUM/Medium, XLARGE=GPU_MEDIUM_8/Large.
    :param provisioning_type: Optional override for provisioning type. If not
        provided, the type is auto-detected by calling the Databricks
        optimization-info API. Must be ``"PROVISIONED_THROUGHPUT"`` or ``"GPU"``.
    :param traffic_config: Optional traffic distribution dict mapping version
        numbers to traffic percentages. If None, traffic is distributed equally.
        Example: ``{1: 70, 2: 30}`` means version 1 gets 70%, version 2 gets 30%
    :param cost_component_name: **Required** cost component for tracking. If not
        provided, will attempt to determine from settings or environment, but will
        raise ValueError if not found. Always specify this explicitly for
        production deployments.
    :param tags: Optional additional tags as key-value pairs
    :param scale_to_zero_enabled: Whether to enable scale-to-zero for the endpoint
        (default: True). When True, endpoint scales to zero when idle. When False,
        endpoint stays always warm.
    :param wait_for_ready: Whether to wait for endpoint to be ready before returning
    :param timeout: Maximum seconds to wait if wait_for_ready is True

    Returns
    ^^^^^^^
    Endpoint details dictionary

    Examples
    ^^^^^^^^
    .. code-block:: python
        :caption: Deploy provisioned throughput model (auto-detected)

        from ml_toolkit.functions.llm.serving_endpoints import deploy_model_serving_endpoint

        endpoint = deploy_model_serving_endpoint(
            model_name="catalog.schema.tiny_llama",
            model_versions=[1],
            endpoint_scaling_size="XSMALL",
            cost_component_name="ml_research",
            wait_for_ready=True,
        )

    .. code-block:: python
        :caption: Deploy GPU model with explicit provisioning type

        endpoint = deploy_model_serving_endpoint(
            model_name="catalog.schema.custom_model",
            model_versions=[1],
            endpoint_scaling_size="MEDIUM",
            provisioning_type="GPU",
            cost_component_name="data_integration_rd",
        )

    .. code-block:: python
        :caption: Deploy multiple versions with A/B testing

        endpoint = deploy_model_serving_endpoint(
            model_name="catalog.schema.llama_model",
            model_versions=[1, 2],
            traffic_config={1: 50, 2: 50},
            endpoint_scaling_size="LARGE",
            cost_component_name="ml_research",
        )
    """
    # Infer endpoint name from model name if not provided
    if endpoint_name is None:
        endpoint_name = infer_endpoint_name_from_model(model_name)
        logger.info(
            f"Endpoint name not provided. Inferred '{endpoint_name}' from model '{model_name}'"
        )
    else:
        # Validate endpoint name if explicitly provided
        validate_endpoint_name(endpoint_name)

    # Check if endpoint already exists with models (strict mode)
    if check_endpoint_exists(endpoint_name):
        try:
            existing_endpoint = _get_serving_endpoint(endpoint_name)
            existing_entities = existing_endpoint.get("config", {}).get(
                "served_entities", []
            )
            if existing_entities:
                # Extract entity names from existing entities
                existing_entity_names = [e.get("name", "") for e in existing_entities]
                raise MLOpsToolkitEndpointAlreadyExistsException(
                    endpoint_name=endpoint_name,
                    existing_versions=existing_entity_names,
                )
        except MLOpsToolkitEndpointAlreadyExistsException:
            # Re-raise our custom exception
            raise
        except Exception:
            # If we can't get endpoint details, continue (will be caught by API call later)
            pass

    # Validate model_versions is a list
    if not isinstance(model_versions, list) or len(model_versions) == 0:
        raise ValueError(
            f"model_versions must be a non-empty list, got: {model_versions}"
        )

    # Validate all versions are positive integers
    for uc_version in model_versions:
        if not isinstance(uc_version, int) or uc_version < 1:
            raise ValueError(
                f"All model versions must be positive integers, got: {uc_version}"
            )

    # Resolve provisioning type and chunk size via optimization-info API
    provisioning_type, chunk_size = resolve_provisioning_config(
        model_name=model_name,
        model_version=str(model_versions[0]),
        provisioning_type=provisioning_type,
    )

    # Build served entities and convert traffic_config if needed
    served_entities = []
    version_to_served_name = {}
    for version_index, uc_model_version in enumerate(model_versions, start=1):
        # The version identifier is based on position in the list (1, 2, 3...)
        version_id = version_index

        # Convert UC version integer to string for Databricks API
        uc_version_str = str(uc_model_version)

        # Build served entity name using new convention
        served_entity_name = build_served_model_name(
            model_name=model_name,
            version=version_id,
            endpoint_name=endpoint_name,
        )

        # Track version to served name mapping for traffic config conversion
        version_to_served_name[version_id] = served_entity_name

        entity = build_served_entity(
            model_name=model_name,
            model_version=uc_version_str,
            served_entity_name=served_entity_name,
            provisioning_type=provisioning_type,
            endpoint_scaling_size=endpoint_scaling_size,
            scale_to_zero_enabled=scale_to_zero_enabled,
            chunk_size=chunk_size,
        )
        served_entities.append(entity)

    logger.info(
        f"Deploying {len(served_entities)} model(s) to endpoint '{endpoint_name}' "
        f"(provisioning_type={provisioning_type})"
    )

    # Convert traffic_config from version-based to served-name-based if needed
    converted_traffic_config = None
    if traffic_config is not None:
        if isinstance(traffic_config, dict) and len(traffic_config) > 0:
            first_key = next(iter(traffic_config.keys()))
            # Check if keys are integers (version-based) or strings (served-name-based)
            if isinstance(first_key, int):
                # Convert version-based to served-name-based
                converted_traffic_config = {}
                for version, percentage in traffic_config.items():
                    if version not in version_to_served_name:
                        raise ValueError(
                            f"Version {version} specified in traffic_config but not found in models"
                        )
                    served_name = version_to_served_name[version]
                    converted_traffic_config[served_name] = percentage
                logger.info(
                    f"Converted version-based traffic config: {traffic_config} -> {converted_traffic_config}"
                )
            else:
                # Already served-name-based, use as-is
                converted_traffic_config = traffic_config

    return _deploy_serving_endpoint(
        endpoint_name=endpoint_name,
        served_entities=served_entities,
        traffic_config=converted_traffic_config,
        cost_component_name=cost_component_name,
        tags=tags,
        wait_for_ready=wait_for_ready,
        timeout=timeout,
    )
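
# Illustration (hypothetical values, not part of the module's logic): with
# model_versions=[1, 2], the served entities follow the v{version} convention
# described above and are named "v1" and "v2", so a version-based
# traffic_config such as {1: 70, 2: 30} is converted to {"v1": 70, "v2": 30}
# before being passed to _deploy_serving_endpoint.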
@track_usage
def update_endpoint_traffic(
    traffic_config: Dict[int, int],
    model_name: str,
    endpoint_name: Optional[str] = None,
    wait_for_ready: bool = DEFAULT_WAIT_FOR_READY,
    timeout: int = DEFAULT_TIMEOUT,
) -> Dict:
    """
    Update the traffic distribution for an existing multi-version endpoint.

    This is useful for gradual rollouts, canary deployments, or A/B testing
    scenarios where you want to adjust traffic percentages without redeploying
    models.

    Parameters
    ^^^^^^^^^^
    :param traffic_config: Dict mapping version numbers to traffic percentages.
        Must sum to 100. Example: ``{1: 30, 2: 70}``
    :param model_name: Unity Catalog model path (catalog.schema.model). Required.
        Used to infer the endpoint name if endpoint_name is not explicitly provided.
    :param endpoint_name: Optional name of the serving endpoint. If not provided,
        the endpoint name will be automatically inferred from model_name.
    :param wait_for_ready: Whether to wait for endpoint to be ready after update
    :param timeout: Maximum seconds to wait if wait_for_ready is True

    Returns
    ^^^^^^^
    Updated endpoint details dictionary

    Examples
    ^^^^^^^^
    .. code-block:: python
        :caption: Gradually increase traffic to new version using model_name

        # Start: 90% version 1, 10% version 2
        update_endpoint_traffic(
            model_name="catalog.schema.my_model",
            traffic_config={1: 90, 2: 10},
        )

        # After monitoring: 50% split
        update_endpoint_traffic(
            model_name="catalog.schema.my_model",
            traffic_config={1: 50, 2: 50},
        )

        # Final: 100% version 2
        update_endpoint_traffic(
            model_name="catalog.schema.my_model",
            traffic_config={1: 0, 2: 100},
        )

    .. code-block:: python
        :caption: Using explicit endpoint name

        update_endpoint_traffic(
            model_name="catalog.schema.my_model",
            endpoint_name="my-endpoint",
            traffic_config={1: 50, 2: 50},
        )
    """
    # Infer endpoint name from model name if not provided
    if endpoint_name is None:
        endpoint_name = infer_endpoint_name_from_model(model_name)
        logger.info(
            f"Endpoint name not provided. Inferred '{endpoint_name}' from model '{model_name}'"
        )
    else:
        # Validate endpoint name if explicitly provided
        validate_endpoint_name(endpoint_name)

    # Get current endpoint configuration
    current_endpoint = _get_serving_endpoint(endpoint_name=endpoint_name)
    config = current_endpoint.get("config", {})
    served_entities = config.get("served_entities", [])

    if not served_entities:
        raise ValueError(f"No served entities found in endpoint '{endpoint_name}'")

    # Build version to served name mapping
    # Pattern: v{version} (e.g., v1, v2, v3)
    version_to_served_name = {}
    for entity in served_entities:
        entity_name = entity["name"]
        # Extract version from pattern: v{version}
        if entity_name.startswith("v") and entity_name[1:].isdigit():
            try:
                version = int(entity_name[1:])
                version_to_served_name[version] = entity_name
            except ValueError:
                pass

    # Convert version-based traffic config to served-name-based
    converted_traffic_config = {}
    for version, percentage in traffic_config.items():
        if version not in version_to_served_name:
            available_versions = list(version_to_served_name.keys())
            raise ValueError(
                f"Version {version} not found in endpoint '{endpoint_name}'. "
                f"Available versions: {available_versions}"
            )
        served_name = version_to_served_name[version]
        converted_traffic_config[served_name] = percentage

    logger.info(
        f"Updating traffic for endpoint '{endpoint_name}': {traffic_config} -> {converted_traffic_config}"
    )

    # Extract cost_component_name from existing endpoint tags
    cost_component_name = extract_cost_component_from_endpoint(current_endpoint)
    if not cost_component_name:
        logger.warning(
            "No cost_component_name found in existing endpoint tags. "
            "Will attempt to use default from settings or environment."
        )

    return _deploy_serving_endpoint(
        endpoint_name=endpoint_name,
        served_entities=served_entities,
        traffic_config=converted_traffic_config,
        cost_component_name=cost_component_name,
        wait_for_ready=wait_for_ready,
        timeout=timeout,
    )
@track_usage
def add_model_version_to_endpoint(
    version: int,
    model_name: str,
    endpoint_name: Optional[str] = None,
    endpoint_scaling_size: EndpointScalingSize = "SMALL",
    provisioning_type: Optional[ProvisioningType] = None,
    traffic_percentage: Optional[int] = None,
    redistribute_traffic: bool = True,
    scale_to_zero_enabled: bool = True,
    wait_for_ready: bool = DEFAULT_WAIT_FOR_READY,
    timeout: int = DEFAULT_TIMEOUT,
) -> Dict:
    """
    Add a new model version to an existing serving endpoint.

    This function fetches the current endpoint configuration, auto-detects the
    model name and version from existing versions, and adds the new version with
    automatic traffic distribution.

    Automatically determines whether to use **provisioned throughput** or **GPU**
    config. Validates that the new entity's provisioning type matches the existing
    endpoint entities.

    Parameters
    ^^^^^^^^^^
    :param version: Integer version identifier (e.g., 1, 2, 3). The served model
        name will be: ``v{version}`` (e.g., v1, v2, v3)
    :param model_name: Unity Catalog model path (catalog.schema.model). Required.
        Used to infer the endpoint name if endpoint_name is not explicitly provided.
    :param endpoint_name: Optional name of the existing serving endpoint. If not
        provided, the endpoint name will be automatically inferred from model_name.
    :param endpoint_scaling_size: Endpoint scaling size. Defaults to ``"SMALL"``.
        Must be one of: ``"XSMALL"``, ``"SMALL"``, ``"MEDIUM"``, ``"LARGE"``, ``"XLARGE"``.
        For provisioned throughput models, this maps to token-per-second ranges.
        For GPU models, this maps to workload_type + workload_size combos.
    :param provisioning_type: Optional override for provisioning type. If not
        provided, the type is auto-detected by calling the Databricks
        optimization-info API. Must be ``"PROVISIONED_THROUGHPUT"`` or ``"GPU"``.
    :param traffic_percentage: Percentage of traffic to route to this new version
        (0-100). If None and redistribute_traffic=True, traffic is distributed
        equally. If specified, other versions' traffic will be proportionally reduced.
    :param redistribute_traffic: If True, automatically adjust traffic across all
        versions. If False, the new version gets 0% traffic (must manually call
        update_endpoint_traffic).
    :param scale_to_zero_enabled: Whether to enable scale-to-zero for the new
        version (default: True). When True, endpoint scales to zero when idle.
        When False, endpoint stays always warm.
    :param wait_for_ready: Whether to wait for endpoint to be ready before returning
    :param timeout: Maximum seconds to wait if wait_for_ready is True

    Returns
    ^^^^^^^
    Updated endpoint details dictionary

    Examples
    ^^^^^^^^
    .. code-block:: python
        :caption: Add version using model_name (recommended)

        from ml_toolkit.functions.llm.serving_endpoints import add_model_version_to_endpoint

        endpoint = add_model_version_to_endpoint(
            version=3,
            model_name="catalog.schema.my_model",
            endpoint_scaling_size="MEDIUM",
        )

    .. code-block:: python
        :caption: Add version with explicit provisioning type

        endpoint = add_model_version_to_endpoint(
            version=2,
            model_name="catalog.schema.my_model",
            endpoint_name="production_endpoint",
            endpoint_scaling_size="LARGE",
            provisioning_type="GPU",
            traffic_percentage=10,
        )
    """
    # Infer endpoint name from model name if not provided
    if endpoint_name is None:
        endpoint_name = infer_endpoint_name_from_model(model_name)
        logger.info(
            f"Endpoint name not provided. Inferred '{endpoint_name}' from model '{model_name}'"
        )
    else:
        # Validate endpoint name if explicitly provided
        validate_endpoint_name(endpoint_name)

    logger.info(f"Adding version {version} to existing endpoint '{endpoint_name}'")

    # Check if endpoint exists
    if not check_endpoint_exists(endpoint_name):
        raise MLOpsToolkitEndpointNotFoundException(endpoint_name)

    # Get current endpoint configuration
    current_endpoint = _get_serving_endpoint(endpoint_name=endpoint_name)
    config = current_endpoint.get("config", {})
    existing_entities = config.get("served_entities", [])
    existing_traffic = config.get("traffic_config", {})

    if not existing_entities:
        raise ValueError(
            f"Endpoint '{endpoint_name}' has no served entities. "
            f"Use deploy_model_serving_endpoint to create a new endpoint."
        )

    # Auto-detect model_name from existing entities
    model_name = existing_entities[0]["entity_name"]
    logger.info(f"Auto-detected model_name from endpoint: {model_name}")

    # Use the endpoint version number as the UC model version
    model_version = str(version)
    logger.info(
        f"Using UC model version {model_version} for endpoint version v{version}"
    )

    # Validate that the UC model version exists
    try:
        get_mlflow_model_version_obj(model_name, int(model_version))
        logger.info(
            f"Validated: UC model version {model_version} exists for {model_name}"
        )
    except Exception as e:
        error_msg = str(e)
        if (
            "RESOURCE_DOES_NOT_EXIST" in error_msg.upper()
            or "does not exist" in error_msg.lower()
        ):
            raise MLOpsToolkitUCVersionNotFoundForEndpointException(
                endpoint_name=endpoint_name,
                endpoint_version=version,
                model_name=model_name,
                uc_version=int(model_version),
            )
        # Re-raise unexpected errors
        raise

    # Build served model name using new convention
    new_entity_name = build_served_model_name(
        model_name=model_name,
        version=version,
        endpoint_name=endpoint_name,
    )

    # Extract existing served entity names
    existing_entity_names = [entity["name"] for entity in existing_entities]

    # Check if this entity name already exists
    if new_entity_name in existing_entity_names:
        raise ValueError(
            f"Served entity '{new_entity_name}' already exists in endpoint '{endpoint_name}'. "
            f"Use deploy_model_serving_endpoint to update it."
        )

    logger.info(
        f"Current endpoint has {len(existing_entities)} model(s): {existing_entity_names}"
    )

    # Resolve provisioning type and chunk size via optimization-info API
    provisioning_type, chunk_size = resolve_provisioning_config(
        model_name=model_name,
        model_version=model_version,
        provisioning_type=provisioning_type,
    )

    # Validate consistency with existing entities
    _validate_provisioning_consistency(
        existing_entities, provisioning_type, endpoint_name
    )

    new_entity = build_served_entity(
        model_name=model_name,
        model_version=model_version,
        served_entity_name=new_entity_name,
        provisioning_type=provisioning_type,
        endpoint_scaling_size=endpoint_scaling_size,
        scale_to_zero_enabled=scale_to_zero_enabled,
        chunk_size=chunk_size,
    )

    # Add new entity to list
    all_entities = existing_entities + [new_entity]

    # Calculate traffic distribution using helper functions
    if not redistribute_traffic:
        # New model gets 0% traffic
        traffic_config = calculate_traffic_with_zero_for_new_version(
            existing_entity_names=existing_entity_names,
            existing_traffic=existing_traffic,
            new_entity_name=new_entity_name,
        )
        logger.info(
            f"Adding '{new_entity_name}' with 0% traffic (redistribute_traffic=False)"
        )
    elif traffic_percentage is not None:
        # Specific traffic percentage requested
        traffic_config = calculate_traffic_with_specific_percentage(
            existing_entity_names=existing_entity_names,
            existing_entities=existing_entities,
            existing_traffic=existing_traffic,
            new_entity_name=new_entity_name,
            traffic_percentage=traffic_percentage,
        )
        remaining_traffic = 100 - traffic_percentage
        logger.info(
            f"Adding '{new_entity_name}' with {traffic_percentage}% traffic, "
            f"reducing existing models to {remaining_traffic}% total"
        )
    else:
        # Distribute traffic equally across all models (including new one)
        all_entity_names = existing_entity_names + [new_entity_name]
        traffic_config = calculate_automatic_traffic(
            served_entity_names=all_entity_names,
        )
        logger.info(
            f"Distributing traffic equally across {len(all_entities)} models: "
            f"{list(traffic_config.values())}%"
        )

    logger.info(f"New traffic distribution: {traffic_config}")

    # Extract cost_component_name from existing endpoint tags
    cost_component_name = extract_cost_component_from_endpoint(current_endpoint)
    if not cost_component_name:
        logger.warning(
            "No cost_component_name found in existing endpoint tags. "
            "Will attempt to use default from settings or environment."
        )

    # Deploy updated configuration
    return _deploy_serving_endpoint(
        endpoint_name=endpoint_name,
        served_entities=all_entities,
        traffic_config=traffic_config,
        cost_component_name=cost_component_name,
        wait_for_ready=wait_for_ready,
        timeout=timeout,
    )
def _validate_provisioning_consistency(
    existing_entities: List[Dict],
    resolved_type: ProvisioningType,
    endpoint_name: str,
) -> None:
    """Validate that new entity provisioning type matches existing endpoint entities."""
    if not existing_entities:
        return

    first_entity = existing_entities[0]
    existing_has_throughput = (
        "max_provisioned_throughput" in first_entity
        or "min_provisioned_throughput" in first_entity
    )
    existing_has_gpu = "workload_type" in first_entity

    if resolved_type == ProvisioningType.PROVISIONED_THROUGHPUT and existing_has_gpu:
        raise ValueError(
            f"Endpoint '{endpoint_name}' uses GPU config but new version requests "
            f"provisioned throughput. All entities on an endpoint must use the same "
            f"provisioning type. Pass provisioning_type='GPU' to match existing config."
        )

    if resolved_type == ProvisioningType.GPU and existing_has_throughput:
        raise ValueError(
            f"Endpoint '{endpoint_name}' uses provisioned throughput but new version "
            f"requests GPU config. All entities on an endpoint must use the same "
            f"provisioning type. Pass provisioning_type='PROVISIONED_THROUGHPUT' to "
            f"match existing config."
        )
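
# Illustration (hypothetical entity dicts, for reading the check above): an
# existing served entity carrying "workload_type" (e.g., {"name": "v1",
# "workload_type": "GPU_MEDIUM"}) is treated as a GPU deployment, while one
# carrying "min_provisioned_throughput"/"max_provisioned_throughput" is treated
# as provisioned throughput; requesting the other type for a new version
# raises the ValueError above.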
@track_usage
def get_serving_endpoint(
    model_name: str,
    endpoint_name: Optional[str] = None,
) -> Dict:
    """
    Get details of a serving endpoint.

    Parameters
    ^^^^^^^^^^
    :param model_name: Unity Catalog model name (catalog.schema.model_name).
        Required parameter used to infer endpoint_name if not provided.
    :param endpoint_name: Name of the endpoint. Optional - if not provided,
        automatically inferred from model_name.

    Returns
    ^^^^^^^
    Endpoint details dictionary containing configuration, state, and metadata

    Examples
    ^^^^^^^^
    .. code-block:: python

        from ml_toolkit.functions.llm.serving_endpoints import get_serving_endpoint

        # Endpoint name automatically inferred from model name
        endpoint = get_serving_endpoint(model_name="catalog.schema.tiny_llama")
        # Endpoint name inferred as "tiny_llama"
        print(f"State: {endpoint['state']['ready']}")
        print(f"Served entities: {endpoint['config']['served_entities']}")

        # Or specify explicit endpoint name
        endpoint = get_serving_endpoint(
            model_name="catalog.schema.tiny_llama",
            endpoint_name="my-custom-endpoint",
        )
    """
    # Infer endpoint name from model name if not provided
    if endpoint_name is None:
        endpoint_name = infer_endpoint_name_from_model(model_name)
        logger.info(
            f"Endpoint name not provided. Inferred '{endpoint_name}' from model '{model_name}'"
        )
    else:
        validate_endpoint_name(endpoint_name)

    return _get_serving_endpoint(endpoint_name=endpoint_name)
@track_usage
def list_serving_endpoints(
    filter_tags: Optional[Dict[str, str]] = None,
) -> List[Dict]:
    """
    List all serving endpoints in the workspace.

    Parameters
    ^^^^^^^^^^
    :param filter_tags: Optional dict of tags to filter by. Only endpoints with
        ALL specified tags matching will be returned

    Returns
    ^^^^^^^
    List of endpoint dictionaries

    Examples
    ^^^^^^^^
    .. code-block:: python

        from ml_toolkit.functions.llm.serving_endpoints import list_serving_endpoints

        # List all endpoints
        endpoints = list_serving_endpoints()
        for ep in endpoints:
            print(f"{ep['name']}: {ep['state']['ready']}")

        # Filter by cost component
        my_team_endpoints = list_serving_endpoints(
            filter_tags={"cost_component_name": "ml_research"}
        )

        # Filter by multiple tags
        prod_endpoints = list_serving_endpoints(
            filter_tags={"env": "production", "team": "ai-platform"}
        )
    """
    return _list_serving_endpoints(filter_tags=filter_tags)
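
# Illustrative end-to-end flow (a sketch assembled from the docstring examples
# above; model, endpoint, and cost component names are hypothetical):
#
#     endpoint = deploy_model_serving_endpoint(
#         model_name="catalog.schema.tiny_llama",
#         model_versions=[1],
#         cost_component_name="ml_research",
#     )
#     endpoint = add_model_version_to_endpoint(
#         version=2,
#         model_name="catalog.schema.tiny_llama",
#         traffic_percentage=10,
#     )
#     update_endpoint_traffic(
#         model_name="catalog.schema.tiny_llama",
#         traffic_config={1: 0, 2: 100},
#     )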