Source code for ml_toolkit.functions.llm.inference.interface

"""User-facing interface for GPU inference via ``run_remote()``.

Supports both **serverless** and **classic** (Ray-on-Spark) compute via
the ``compute_type`` parameter.

Example (serverless — default)::

    from ml_toolkit.functions.llm import run_inference

    remote_run = run_inference(
        "catalog.schema.input_table",
        "catalog.schema.my_model",
        "catalog.schema.output",
        uc_volumes_path="/Volumes/catalog/schema/vol",
    )
    print(remote_run.databricks_url)
    result = remote_run.get_result()

Example (classic p5.48xlarge with Capacity Blocks)::

    remote_run = run_inference(
        "catalog.schema.input_table",
        "catalog.schema.my_model",
        "catalog.schema.output",
        compute_type="classic",
        gpu_type="h100",
        num_gpus=16,  # → driver (8 GPUs) + 1 worker (8 GPUs)
        capacity_reservation_id="cr-0abc1234def56789a",
        availability_zone="us-east-1a",
    )
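
Example (structured output; illustrative: the HuggingFace model ID, schema,
and ``chat_template_kwargs`` below are placeholders, not toolkit defaults)::

    remote_run = run_inference(
        "catalog.schema.input_table",
        "Qwen/Qwen3-8B",  # HF model ID: output_schema is required
        "catalog.schema.output",
        uc_volumes_path="/Volumes/catalog/schema/vol",
        output_schema="array<struct<l1_category:string,score:double>>",
        chat_template_kwargs={"enable_thinking": False},  # Qwen3: suppress <think>
    )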
"""

from __future__ import annotations

from typing import Literal, Optional

from pyspark.sql import DataFrame
from yipit_databricks_client.helpers.telemetry import track_usage

from ml_toolkit.functions.llm.inference.anyscale_function import AnyscaleKwargs
from ml_toolkit.functions.llm.inference.classic_function import (
    _run_classic_inference_local,
)
from ml_toolkit.functions.llm.inference.function import RemoteInferenceRun
from ml_toolkit.functions.llm.inference.serverless_function import (
    _run_serverless_inference_local,
)
from ml_toolkit.ml.llm.inference.serverless.vllm_serverless import (
    DEFAULT_UC_VOLUMES_PATH,
)
from ml_toolkit.ops.helpers.logger import get_logger
from ml_toolkit.ops.helpers.validation import (
    assert_model_alias_exists,
    assert_model_version_exists,
    assert_table_exists,
)


def _get_installed_version(package: str) -> str | None:
    """Return the installed version of *package*, or ``None`` if missing."""
    from importlib.metadata import PackageNotFoundError, version

    try:
        return version(package)
    except PackageNotFoundError:
        return None


def _version_satisfies(installed: str | None, spec: str) -> bool:
    """Check whether *installed* version satisfies a PEP 440 *spec*."""
    if installed is None:
        return False
    from packaging.specifiers import SpecifierSet

    return installed in SpecifierSet(spec)
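

# Illustrative behaviour of the helpers above (assumes standard PEP 440
# matching via ``packaging``): ``==X.Y.Z`` specifiers ignore local version
# labels, which is why a DBR-provided build such as ``2.9.0+cu129``
# satisfies the ``==2.9.0`` pin below.
#
#   _version_satisfies("2.9.0+cu129", "==2.9.0")  # -> True
#   _version_satisfies("2.48.0", "==2.47.1")      # -> False
#   _version_satisfies(None, ">=12,<13")          # -> False (not installed)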


# -- Target versions ---------------------------------------------------------
_TORCH_VERSION = "==2.9.0"
_FLASH_ATTN_VERSION = "==2.8.3"
_RAY_VERSION = "==2.47.1"
_VLLM_VERSION = "==0.13.0"
_HF_TRANSFER_VERSION = "==0.1.9"
_EINOPS_VERSION = "==0.8.2"
_OPENCV_VERSION = "==4.12.0.88"
_PYARROW_VERSION = "==21"
_CUDA_PYTHON_VERSION = ">=12,<13"


def install_inference_gpu_requirements() -> None:
    """Install GPU inference dependencies, skipping what's already satisfied.

    Auto-detects whether the current environment is **serverless** or
    **classic** (standard) via ``detect_compute_environment()``.

    On classic clusters (e.g. DBR 18.1 ML), many packages are
    pre-installed (torch, flash-attn, transformers).  This function
    checks installed versions and only installs what's missing or
    outdated — avoiding the expensive torch reinstall and flash-attn
    source compilation when not needed.

    On serverless environments, the full install is performed since
    packages are not pre-installed.

    Uses the IPython ``%pip`` magic (via ``get_ipython().run_line_magic``)
    instead of raw ``subprocess`` so that Databricks tracks environment
    changes for worker node propagation and environment snapshotting.

    Install order (when needed):
        1. Base requirements (from ``gpu_requirements.txt``)
        2. ``torch`` — skipped on classic if ``2.9.0`` is already installed
        3. Build tooling + ``flash-attn`` — skipped if already satisfied
        4. ``opencv-python-headless`` (OpenSSL FIPS fix)
        5. ``ray[data]``, ``hf_transfer``, ``einops``
        6. ``vllm``
        7. Re-pin ``ray[data]`` if the ``vllm`` install pulled a newer version
        8. Classic only: ``pyarrow`` and ``cuda-python``
    """
    from IPython import get_ipython

    from ml_toolkit.ops.helpers.environment import detect_compute_environment

    env = detect_compute_environment()
    is_serverless = env == "serverless"

    logger = get_logger()
    logger.info(f"Detected compute environment: {env}")

    _ipy = get_ipython()

    def _pip(*args: str) -> None:
        _ipy.run_line_magic("pip", "install " + " ".join(args))

    # -- 1. Base requirements (always) ----------------------------------------
    import os

    from ml_toolkit.functions.llm.inference.function import GPU_REQUIREMENTS_PATH
    from ml_toolkit.ops.settings import REPO_BASE_PATH

    req_file = os.path.join(REPO_BASE_PATH, GPU_REQUIREMENTS_PATH)
    if os.path.exists(req_file):
        _pip("-r", req_file)

    # -- 2. torch -------------------------------------------------------------
    # Serverless: always install cu128 build (clean env).
    # Classic: skip if a compatible torch is already installed (e.g.
    # DBR 18.1 ships torch 2.9.0+cu129 which works with vLLM).
    torch_installed = _get_installed_version("torch")
    torch_ok = not is_serverless and _version_satisfies(torch_installed, _TORCH_VERSION)

    if torch_ok:
        logger.info(
            f"torch {torch_installed} already satisfies {_TORCH_VERSION}, skipping"
        )
    else:
        logger.info("Installing torch+cu128")
        _pip(
            "--no-cache-dir",
            "torch==2.9.0+cu128",
            "--index-url",
            "https://download.pytorch.org/whl/cu128",
        )

    # -- 3. flash-attn --------------------------------------------------------
    # Expensive source compilation (~5-10 min).  Skip if already installed
    # at the target version (e.g. DBR 18.1 ships flash-attn 2.8.3
    # pre-compiled against its torch build).
    flash_installed = _get_installed_version("flash-attn")
    flash_ok = _version_satisfies(flash_installed, _FLASH_ATTN_VERSION)

    if flash_ok:
        logger.info(
            f"flash-attn {flash_installed} already satisfies {_FLASH_ATTN_VERSION}, skipping"
        )
    else:
        logger.info("Installing build tooling + flash-attn from source")
        _pip(
            "-U", "--no-cache-dir", "wheel==0.46.3", "ninja==1.13.0", "packaging==26.0"
        )
        _pip(
            "--force-reinstall",
            "--no-cache-dir",
            "--no-build-isolation",
            "--no-deps",
            "flash-attn==2.8.3",
        )

    # -- 4. opencv-python-headless (OpenSSL FIPS fix) -------------------------
    opencv_installed = _get_installed_version("opencv-python-headless")
    if not _version_satisfies(opencv_installed, _OPENCV_VERSION):
        _pip("--no-cache-dir", "opencv-python-headless==4.12.0.88")

    # -- 5. ray, hf_transfer, einops -----------------------------------------
    # Ray 2.47.1 for all compute types (Databricks RayPatches incompatible
    # with BlockMetadata changes in Ray >=2.48).
    _ray_spec = _RAY_VERSION
    _ray_pin = "ray[data]==2.47.1"

    to_install = []
    ray_installed = _get_installed_version("ray")
    if not _version_satisfies(ray_installed, _ray_spec):
        to_install.append(_ray_pin)
    else:
        logger.info(f"ray {ray_installed} already satisfies {_ray_spec}, skipping")

    if not _version_satisfies(
        _get_installed_version("hf-transfer"), _HF_TRANSFER_VERSION
    ):
        to_install.append("hf_transfer==0.1.9")

    if not _version_satisfies(_get_installed_version("einops"), _EINOPS_VERSION):
        to_install.append("einops==0.8.2")

    if to_install:
        _pip(*to_install)

    # -- 6. vLLM --------------------------------------------------------------
    vllm_installed = _get_installed_version("vllm")
    if not _version_satisfies(vllm_installed, _VLLM_VERSION):
        _pip("vllm==0.13.0")
    else:
        logger.info(
            f"vllm {vllm_installed} already satisfies {_VLLM_VERSION}, skipping"
        )

    # -- 7. Re-pin ray after vllm (vllm may pull a newer ray as transitive dep)
    ray_after_vllm = _get_installed_version("ray")
    if not _version_satisfies(ray_after_vllm, _RAY_VERSION):
        logger.info(
            f"ray drifted to {ray_after_vllm} after vllm install, "
            f"re-pinning to {_ray_pin}"
        )
        _pip(_ray_pin)

    # -- 8. Classic-only: pyarrow + cuda-python --------------------------------
    # pyarrow 21 is needed for Arrow-based Spark↔Ray data exchange on
    # PySpark 4.x (DBR 18.1+).  cuda-python provides CUDA bindings that
    # vLLM requires on classic clusters (serverless bundles them).
    if not is_serverless:
        pyarrow_installed = _get_installed_version("pyarrow")
        if not _version_satisfies(pyarrow_installed, _PYARROW_VERSION):
            _pip("--force-reinstall", "--no-deps", "pyarrow==21")

        cuda_python_installed = _get_installed_version("cuda-python")
        if not _version_satisfies(cuda_python_installed, _CUDA_PYTHON_VERSION):
            _pip('"cuda-python>=12,<13"')

    logger.info("GPU inference requirements ready")


def _validate_three_part_name(value: str, label: str) -> None:
    """Raise ValueError if *value* is not a three-part dotted name."""
    parts = value.split(".")
    if len(parts) != 3 or not all(parts):
        raise ValueError(
            f"`{label}` must be in the format `catalog.schema.name`, got: {value!r}"
        )


@track_usage
def run_inference(
    input_source: str | DataFrame,
    model_name: str,
    output_table: str,
    *,
    prompt_column: str = "prompt",
    model_version: Optional[int | Literal["latest"]] = None,
    model_alias: Optional[str] = None,
    system_prompt: Optional[str] = None,
    max_new_tokens: Optional[int] = None,
    temperature: Optional[float] = None,
    top_p: Optional[float] = None,
    top_k: Optional[int] = None,
    min_p: Optional[float] = None,
    presence_penalty: Optional[float] = 0.0,
    repetition_penalty: Optional[float] = 1.0,
    max_model_len: Optional[int] = None,
    num_gpus: int = 4,
    gpu_type: str = "a10",
    concurrency: Optional[int] = None,
    batch_size: Optional[int] = None,
    tensor_parallel_size: int = 1,
    gpu_memory_utilization: Optional[float] = 0.85,
    max_num_seqs: Optional[int] = None,
    max_num_batched_tokens: Optional[int] = None,
    kv_cache_dtype: Optional[str] = None,
    dtype: Optional[str] = None,
    quantization: Optional[str] = None,
    uc_volumes_path: Optional[str] = None,
    auto_build_prompt: bool = True,
    output_column: str = "model_output",
    pk_column: Optional[str] = None,
    table_operation: Optional[Literal["overwrite", "append"]] = "overwrite",
    ray_object_store_memory: float = 0.5,
    http_pool_maxsize: int = 64,
    trigger_remote: bool = True,
    wait: bool = False,
    compute_type: Literal["serverless", "classic", "anyscale"] = "serverless",
    capacity_reservation_id: Optional[str] = "auto",
    availability_zone: Optional[str] = "auto",
    ray_kwargs: Optional[dict] = None,
    env_vars: Optional[dict[str, str]] = None,
    output_schema: Optional[object] = None,
    chat_template_kwargs: Optional[dict] = None,
    output_storage_location: Optional[str] = None,
    output_bucket: Optional[str] = None,
    checkpoint_path: Optional[str] = None,
    anyscale_kwargs: "Optional[AnyscaleKwargs | dict]" = None,
) -> RemoteInferenceRun | dict | None:
    """Run vLLM batch inference on Databricks GPU compute.

    Supports three compute backends via ``compute_type``:

    - ``"serverless"`` (default) — Databricks Serverless GPU. Uses A10 or
      H100 accelerators. Requires ``uc_volumes_path`` for data staging.
    - ``"classic"`` — Databricks classic cluster with p5.48xlarge (8x H100)
      or p6-b200.48xlarge (8x B200) nodes. Uses Ray-on-Spark for
      distributed inference. GPUs in multiples of 8. The driver
      contributes 8 GPUs; the number of additional worker nodes is
      auto-computed as ``max(0, (num_gpus - 8) // 8)``. Supports AWS
      Capacity Blocks.
    - ``"anyscale"`` submits the job to Anyscale. Always runs remotely;
      ``trigger_remote=False`` is not supported.

    When ``trigger_remote=True`` (default), the function validates inputs,
    serializes parameters, and submits a remote job via ``run_remote()``.
    When ``trigger_remote=False``, inference runs directly on the current
    cluster (must already have GPUs available).

    Args:
        input_source: Fully qualified UC table name (``catalog.schema.table``)
            or DataFrame.
        model_name: UC model name (``catalog.schema.model``).
        output_table: Fully qualified output table name.
        prompt_column: Column containing prompt text. Defaults to ``"prompt"``.
        model_version: Specific model version (int) or ``"latest"``.
            Mutually exclusive with ``model_alias``. ``None`` = production
            alias.
        model_alias: UC alias (e.g. ``"champion"``). Mutually exclusive with
            ``model_version``.
        system_prompt: System instruction for chat completions.
        max_new_tokens: Maximum tokens to generate per prompt.
        temperature: Sampling temperature (0.0 = greedy).
        top_p: Nucleus sampling threshold.
        top_k: Controls the number of top tokens to consider. Set to 0
            (or -1) to consider all tokens.
        min_p: Represents the minimum probability for a token to be
            considered, relative to the probability of the most likely
            token. Must be in [0, 1].
        presence_penalty: Penalizes new tokens based on whether they appear
            in the generated text so far. Values > 0 encourage the model to
            use new tokens, while values < 0 encourage the model to repeat
            tokens.
        repetition_penalty: Penalizes new tokens based on whether they
            appear in the prompt and the generated text so far. Values > 1
            encourage the model to use new tokens, while values < 1
            encourage the model to repeat tokens.
        max_model_len: Maximum sequence length.
        num_gpus: Number of GPUs. Defaults to 4. For
            ``compute_type="classic"``, must be a multiple of 8 (H100 only)
            and counts the total across driver + workers. ``num_gpus=8``
            runs driver-only; ``num_gpus=16`` runs driver + 1 worker node.
        gpu_type: GPU type — ``"a10"``, ``"h100"``, or ``"b200"``. Defaults
            to ``"a10"``. For ``compute_type="classic"``, must be
            ``"h100"`` or ``"b200"``.
        concurrency: Number of vLLM data-parallel replicas.
        batch_size: Rows per inference batch.
        tensor_parallel_size: GPUs per vLLM instance.
        gpu_memory_utilization: Fraction of GPU memory to use for vLLM.
        max_num_seqs: Maximum concurrent sequences per vLLM engine.
        max_num_batched_tokens: Maximum tokens per vLLM iteration step.
        kv_cache_dtype: Data type for KV cache (e.g. ``"fp8"``).
        dtype: Model weight data type for vLLM (e.g. ``"bfloat16"``,
            ``"float16"``, ``"auto"``). ``None`` = auto-resolved from model
            config.
        quantization: vLLM quantization method. ``"auto"`` = resolved from
            model_config → GPU defaults (``"fp8"`` for H100/B200, ``None``
            for A10). ``"fp8"`` = dynamically quantize weights to FP8_E4M3
            at load time. ``None`` = don't pass quantization to vLLM
            (auto-detect from model).
        uc_volumes_path: UC Volume path for staging data. Defaults to
            ``DEFAULT_UC_VOLUMES_PATH`` when ``None``.
        auto_build_prompt: Auto-build prompt column from model template.
        output_column: Name of the output column containing model
            predictions. Defaults to ``"model_output"``.
        pk_column: Name of an existing column to use as the row identifier
            for joining inference results back to the input. If ``None``
            (default), a synthetic row ID is generated automatically. Use
            this when your input already has a unique primary key column to
            avoid creating a redundant one.
        table_operation: How to write the output table — ``"overwrite"``
            (default) replaces the table, ``"append"`` adds rows to an
            existing table.
        ray_object_store_memory: Fraction of cluster memory for Ray object
            store. Defaults to 0.5.
        http_pool_maxsize: Max concurrent HTTP connections for model
            download.
        trigger_remote: If ``True``, submit via ``run_remote()``. If
            ``False``, run directly on current cluster.
        wait: If ``True`` and ``trigger_remote=True``, block until
            completion and return the result dict.
        compute_type: ``"serverless"`` (default), ``"classic"``, or
            ``"anyscale"``.
        capacity_reservation_id: AWS Capacity Block reservation ID (e.g.
            ``"cr-0abc1234def56789a"``). Classic only. Configures the
            cluster with on-demand instances and the
            ``X-Databricks-AwsCapacityBlockId`` tag. Pass ``"auto"`` to
            fetch from the ``WORKSPACE_CONFIGURATION`` secret scope (keys
            ``GPU_CAPACITY_RESERVATION_ID`` and ``GPU_AVAILABILITY_ZONE``).
        availability_zone: AWS availability zone for the Capacity Block
            (e.g. ``"us-east-1a"``). Required when using
            ``capacity_reservation_id`` (auto-resolved when ``"auto"``).
        ray_kwargs: Dict of Ray-on-Spark tuning parameters forwarded to
            ``run_classic_inference`` (classic only). Supported keys:
            ``num_cpus_per_node``, ``num_cpus_head_node``,
            ``num_gpus_head_node``.
        env_vars: Optional dictionary of environment variables to set in
            each Ray worker process before inference starts. Ray actors do
            not inherit the driver's environment, so diagnostic variables
            like ``VLLM_LOGGING_LEVEL=DEBUG`` must be forwarded explicitly.
            Defaults to ``None`` (no extra variables).
        output_schema: Optional PySpark ``DataType`` describing the JSON
            shape the model is expected to produce (e.g.
            ``ArrayType(StructType([...]))``). When set, the output table
            gets a parsed ``model_output`` struct column alongside the raw
            text in ``model_output_raw``. Accepts a ``DataType`` instance,
            a DDL string (e.g. ``"array<struct<l1_category:string,...>>"``),
            or the dict form produced by
            :func:`ml_toolkit.ops.helpers.mlflow.serialize_pyspark_schema`.
            Required for HuggingFace model IDs (no MLflow signature); for
            UC models, overrides whatever the MLflow signature declared.
        chat_template_kwargs: Extra keyword arguments forwarded to the
            tokenizer's ``apply_chat_template`` call when prompts are
            rendered. Use this to toggle model-family-specific flags such
            as ``{"enable_thinking": False}`` for Qwen3 (suppresses
            ``<think>`` blocks) or ``{"thinking": True}`` for DeepSeek-V3.1
            / Granite 3.2. The toolkit does not interpret the dict — it
            splats the keys into the call, so callers are responsible for
            using the kwarg name their model expects. Defaults to ``None``
            (preserves prior behaviour). Applied on both serverless and
            classic compute paths.
        output_storage_location: Storage location for inference output
            (Anyscale only).
        output_bucket: Bucket used for Anyscale output and checkpoint data
            (Anyscale only).
        checkpoint_path: Ray Data checkpoint path for resuming an Anyscale
            run. Auto-generated and logged when ``None`` (Anyscale only).
        anyscale_kwargs: ``AnyscaleKwargs`` instance (or equivalent dict) of
            Anyscale submission options such as compute config, image, and
            retry limits (Anyscale only).

    Returns:
        - ``RemoteInferenceRun`` when ``trigger_remote=True`` and
          ``wait=False``.
        - Result ``dict`` when ``trigger_remote=True`` and ``wait=True``.
        - ``None`` when ``trigger_remote=False``.

    Raises:
        ValueError: If validation fails (table format, GPU config, etc.).
    """
    logger = get_logger()

    from yipit_databricks_client import get_dbutils

    dbutils = get_dbutils()

    is_classic = compute_type == "classic"
    is_anyscale = compute_type == "anyscale"
    is_serverless = compute_type == "serverless"

    # ------------------------------------------------------------------
    # Validation
    # ------------------------------------------------------------------
    from ml_toolkit.ml.llm.inference._model_utils import _is_hf_model_id

    _is_hf = _is_hf_model_id(model_name)
    if not _is_hf:
        _validate_three_part_name(model_name, "model_name")
    _validate_three_part_name(output_table, "output_table")

    if isinstance(input_source, str):
        assert_table_exists(input_source)
    elif isinstance(input_source, DataFrame) and trigger_remote:
        raise ValueError("Input source must be a table name when trigger_remote=True")

    if _is_hf and (model_version is not None or model_alias is not None):
        raise ValueError(
            "model_version and model_alias are not supported for "
            f"HuggingFace model IDs, got model_name={model_name!r}"
        )

    if model_version is not None and model_alias is not None:
        raise ValueError(
            "Cannot specify both `model_version` and `model_alias`. "
            f"Got model_version={model_version!r} and model_alias={model_alias!r}."
        )

    if model_version is not None and model_version != "latest":
        if not isinstance(model_version, int):
            raise ValueError(
                f"`model_version` must be an int or 'latest', got: {model_version!r}"
            )

    if model_alias is not None:
        if not isinstance(model_alias, str) or not model_alias.strip():
            raise ValueError(
                f"`model_alias` must be a non-empty string, got: {model_alias!r}"
            )

    if not _is_hf and model_alias is not None:
        assert_model_alias_exists(model_name, model_alias)
    if not _is_hf and model_version is not None and model_version != "latest":
        assert_model_version_exists(model_name, model_version)

    gpu_type = gpu_type.lower()
    if is_classic:
        if gpu_type not in ("h100", "b200"):
            raise ValueError(
                f"Classic compute only supports H100/B200 GPUs, got gpu_type={gpu_type!r}"
            )
        if num_gpus % 8 != 0:
            raise ValueError(
                f"num_gpus must be a multiple of 8 for classic compute, got {num_gpus}"
            )
        if capacity_reservation_id == "auto":
            capacity_reservation_id = dbutils.secrets.get(
                "WORKSPACE_CONFIGURATION", "GPU_CAPACITY_RESERVATION_ID"
            )
            availability_zone = dbutils.secrets.get(
                "WORKSPACE_CONFIGURATION", "GPU_AVAILABILITY_ZONE"
            )
            logger.info(
                f"Resolved capacity block from secrets: "
                f"reservation={capacity_reservation_id}, az={availability_zone}"
            )
        if capacity_reservation_id is not None and availability_zone is None:
            raise ValueError(
                "`availability_zone` is required when using "
                "`capacity_reservation_id`. Specify the AZ assigned to your "
                "Capacity Block (e.g. 'us-east-1a')."
            )
    else:
        if gpu_type not in ("a10", "h100"):
            raise ValueError(f"gpu_type must be 'a10' or 'h100', got: {gpu_type!r}")
        if gpu_type == "h100" and num_gpus % 8 != 0:
            raise ValueError(
                f"num_gpus must be a multiple of 8 for H100, got {num_gpus}"
            )

    if uc_volumes_path is None:
        uc_volumes_path = DEFAULT_UC_VOLUMES_PATH
    logger.info(f"Using uc_volumes_path: {uc_volumes_path}")

    # Normalize output_schema to a JSON-serializable dict so it survives
    # kwargs forwarding across the remote-submit boundary. The core
    # inference functions deserialize back to a DataType.
    if output_schema is not None and not isinstance(output_schema, dict):
        from ml_toolkit.ops.helpers.mlflow import serialize_pyspark_schema

        output_schema = serialize_pyspark_schema(output_schema, "output_schema")

    # ------------------------------------------------------------------
    # Collect optional params (shared across all compute types)
    # ------------------------------------------------------------------
    optional_params: dict = {
        "prompt_column": prompt_column,
        "model_version": model_version,
        "model_alias": model_alias,
        "system_prompt": system_prompt,
        "max_new_tokens": max_new_tokens,
        "temperature": temperature,
        "top_p": top_p,
        "top_k": top_k,
        "min_p": min_p,
        "presence_penalty": presence_penalty,
        "repetition_penalty": repetition_penalty,
        "max_model_len": max_model_len,
        "concurrency": concurrency,
        "num_gpus": num_gpus,
        "gpu_type": gpu_type,
        "batch_size": batch_size,
        "tensor_parallel_size": tensor_parallel_size,
        "gpu_memory_utilization": gpu_memory_utilization,
        "max_num_seqs": max_num_seqs,
        "max_num_batched_tokens": max_num_batched_tokens,
        "kv_cache_dtype": kv_cache_dtype,
        "dtype": dtype,
        "quantization": quantization,
        "uc_volumes_path": uc_volumes_path,
        "auto_build_prompt": auto_build_prompt,
        "output_column": output_column,
        "pk_column": pk_column,
        "table_operation": table_operation,
        "ray_object_store_memory": ray_object_store_memory,
        "http_pool_maxsize": http_pool_maxsize,
        "env_vars": env_vars,
        "output_schema": output_schema,
        "chat_template_kwargs": chat_template_kwargs,
    }

    if is_classic:
        optional_params["ray_kwargs"] = ray_kwargs
    elif is_anyscale:
        optional_params["output_storage_location"] = output_storage_location
        optional_params["output_bucket"] = output_bucket
        if checkpoint_path is None:
            from ml_toolkit.ml.llm.inference.anyscale._delta_io import (
                DEFAULT_OUTPUT_BUCKET,
                generate_checkpoint_path,
            )

            _bucket = output_bucket or DEFAULT_OUTPUT_BUCKET
            _catalog, _schema, _table = output_table.split(".")
            checkpoint_path = generate_checkpoint_path(_bucket, _schema, _table)
            logger.info(
                f"Ray Data checkpoint path: {checkpoint_path} "
                f"(pass back via checkpoint_path=... to resume)"
            )
        optional_params["checkpoint_path"] = checkpoint_path
    else:
        optional_params["uc_volumes_path"] = uc_volumes_path

    # ------------------------------------------------------------------
    # Dispatch
    # ------------------------------------------------------------------
    if is_anyscale:
        if not trigger_remote:
            raise ValueError(
                "compute_type='anyscale' always submits a remote job; "
                "trigger_remote=False is not supported."
            )

        from ml_toolkit.functions.llm.inference.anyscale_function import (
            submit_anyscale_inference,
        )

        if anyscale_kwargs is None:
            anyscale_kwargs = AnyscaleKwargs()
        elif isinstance(anyscale_kwargs, dict):
            anyscale_kwargs = AnyscaleKwargs(**anyscale_kwargs)

        logger.info(
            f"Submitting Anyscale inference: model={model_name}, "
            f"input={input_source}, output={output_table}"
        )
        remote_run = submit_anyscale_inference(
            input_source,
            model_name,
            output_table,
            compute_config=anyscale_kwargs.compute_config,
            image_uri=anyscale_kwargs.image_uri,
            job_kind=anyscale_kwargs.job_kind,
            team=anyscale_kwargs.team,
            py_modules=anyscale_kwargs.py_modules,
            working_dir=anyscale_kwargs.working_dir,
            excludes=anyscale_kwargs.excludes,
            requirements=anyscale_kwargs.requirements,
            reasoning_parser=anyscale_kwargs.reasoning_parser,
            max_actor_restarts=anyscale_kwargs.max_actor_restarts,
            max_task_retries=anyscale_kwargs.max_task_retries,
            **optional_params,
        )
        return remote_run.get_result() if wait else remote_run

    if is_classic and trigger_remote:
        from ml_toolkit.functions.llm.inference.classic_function import (
            submit_classic_inference,
        )

        # The driver is a p5.48xlarge / p6-b200.48xlarge with 8 GPUs; it
        # runs a vLLM replica alongside coordination. ``num_gpus`` is the
        # total across driver + workers, so workers cover whatever is
        # left after the driver's 8.
        num_workers = max(0, (num_gpus - 8) // 8)

        logger.info(
            f"Submitting remote classic inference: model={model_name}, "
            f"input={input_source}, output={output_table}, "
            f"workers={num_workers}"
        )
        # gpu_type is passed explicitly for cluster config selection;
        # also remove it from optional_params to avoid duplicate kwarg
        # (it's still forwarded to the runner via _build_runner_kwargs).
        optional_params.pop("gpu_type", None)
        remote_run = submit_classic_inference(
            input_source,
            model_name,
            output_table,
            gpu_type=gpu_type,
            num_workers=num_workers,
            capacity_reservation_id=capacity_reservation_id,
            availability_zone=availability_zone,
            **optional_params,
        )
        return remote_run.get_result() if wait else remote_run

    elif is_serverless and trigger_remote:
        from ml_toolkit.functions.llm.inference.serverless_function import (
            submit_serverless_inference,
        )

        logger.info(
            f"Submitting remote serverless inference: model={model_name}, "
            f"input={input_source}, output={output_table}"
        )
        # gpu_type/num_gpus are explicit kwargs on submit_serverless_inference;
        # drop them from optional_params to avoid duplicate kwargs.
        optional_params.pop("gpu_type", None)
        optional_params.pop("num_gpus", None)
        remote_run = submit_serverless_inference(
            input_source,
            model_name,
            output_table,
            num_gpus=num_gpus,
            gpu_type=gpu_type,
            **optional_params,
        )
        return remote_run.get_result() if wait else remote_run

    elif is_classic and not trigger_remote:
        logger.info(
            f"Running local classic inference: model={model_name}, "
            f"input={input_source}, output={output_table}"
        )
        _run_classic_inference_local(
            input_source,
            model_name,
            output_table,
            **optional_params,
        )
        return None

    elif is_serverless and not trigger_remote:
        logger.info(
            f"Running local inference: model={model_name}, "
            f"input={input_source}, output={output_table}"
        )
        _run_serverless_inference_local(
            input_source,
            model_name,
            output_table,
            **optional_params,
        )
        return None