"""User-facing interface for GPU inference via ``run_remote()``.
Supports **serverless**, **classic** (Ray-on-Spark), and **anyscale**
compute via the ``compute_type`` parameter.
Example (serverless — default)::
from ml_toolkit.functions.llm import run_inference
remote_run = run_inference(
"catalog.schema.input_table",
"catalog.schema.my_model",
"catalog.schema.output",
uc_volumes_path="/Volumes/catalog/schema/vol",
)
print(remote_run.databricks_url)
result = remote_run.get_result()
Example (classic p5.48xlarge with Capacity Blocks)::
remote_run = run_inference(
"catalog.schema.input_table",
"catalog.schema.my_model",
"catalog.schema.output",
compute_type="classic",
gpu_type="h100",
num_gpus=16, # → driver (8 GPUs) + 1 worker (8 GPUs)
capacity_reservation_id="cr-0abc1234def56789a",
availability_zone="us-east-1a",
)
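
Example (run on the current GPU cluster, no remote submit)::

    run_inference(
        "catalog.schema.input_table",
        "catalog.schema.my_model",
        "catalog.schema.output",
        trigger_remote=False,  # the attached cluster must already have GPUs
    )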
"""
from __future__ import annotations
from typing import Literal, Optional
from pyspark.sql import DataFrame
from yipit_databricks_client.helpers.telemetry import track_usage
from ml_toolkit.functions.llm.inference.anyscale_function import AnyscaleKwargs
from ml_toolkit.functions.llm.inference.classic_function import (
_run_classic_inference_local,
)
from ml_toolkit.functions.llm.inference.function import RemoteInferenceRun
from ml_toolkit.functions.llm.inference.serverless_function import (
_run_serverless_inference_local,
)
from ml_toolkit.ml.llm.inference.serverless.vllm_serverless import (
DEFAULT_UC_VOLUMES_PATH,
)
from ml_toolkit.ops.helpers.logger import get_logger
from ml_toolkit.ops.helpers.validation import (
assert_model_alias_exists,
assert_model_version_exists,
assert_table_exists,
)
def _get_installed_version(package: str) -> str | None:
"""Return the installed version of *package*, or ``None`` if missing."""
from importlib.metadata import PackageNotFoundError, version
try:
return version(package)
except PackageNotFoundError:
return None
def _version_satisfies(installed: str | None, spec: str) -> bool:
"""Check whether *installed* version satisfies a PEP 440 *spec*."""
if installed is None:
return False
from packaging.specifiers import SpecifierSet
return installed in SpecifierSet(spec)
# -- Target versions ---------------------------------------------------------
_TORCH_VERSION = "==2.9.0"
_FLASH_ATTN_VERSION = "==2.8.3"
_RAY_VERSION = "==2.47.1"
_VLLM_VERSION = "==0.13.0"
_HF_TRANSFER_VERSION = "==0.1.9"
_EINOPS_VERSION = "==0.8.2"
_OPENCV_VERSION = "==4.12.0.88"
_PYARROW_VERSION = "==21"
_CUDA_PYTHON_VERSION = ">=12,<13"
def install_inference_gpu_requirements() -> None:
"""Install GPU inference dependencies, skipping what's already satisfied.
Auto-detects whether the current environment is **serverless** or
**classic** (standard) via ``detect_compute_environment()``.
On classic clusters (e.g. DBR 18.1 ML), many packages are
pre-installed (torch, flash-attn, transformers). This function
checks installed versions and only installs what's missing or
outdated — avoiding the expensive torch reinstall and flash-attn
source compilation when not needed.
On serverless environments, the full install is performed since
packages are not pre-installed.
Uses the IPython ``%pip`` magic (via ``get_ipython().run_line_magic``)
instead of raw ``subprocess`` so that Databricks tracks environment
changes for worker node propagation and environment snapshotting.
Install order (when needed):
1. Base requirements (from ``gpu_requirements.txt``)
2. ``torch`` — skipped on classic if already ``>=2.9.0``
3. Build tooling + ``flash-attn`` — skipped if already satisfied
4. ``opencv-python-headless`` (OpenSSL FIPS fix)
5. ``ray[data]``, ``hf_transfer``, ``einops``
6. ``vllm``
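
    A minimal usage sketch (assuming this helper is exported alongside
    ``run_inference``)::

        from ml_toolkit.functions.llm import install_inference_gpu_requirements

        install_inference_gpu_requirements()  # idempotent: skips satisfied pins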
"""
from IPython import get_ipython
from ml_toolkit.ops.helpers.environment import detect_compute_environment
env = detect_compute_environment()
is_serverless = env == "serverless"
logger = get_logger()
logger.info(f"Detected compute environment: {env}")
_ipy = get_ipython()
def _pip(*args: str) -> None:
_ipy.run_line_magic("pip", "install " + " ".join(args))
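        # e.g. _pip("-r", req_file) runs the notebook line magic:
        #   %pip install -r <req_file>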
# -- 1. Base requirements (always) ----------------------------------------
import os
from ml_toolkit.functions.llm.inference.function import GPU_REQUIREMENTS_PATH
from ml_toolkit.ops.settings import REPO_BASE_PATH
req_file = os.path.join(REPO_BASE_PATH, GPU_REQUIREMENTS_PATH)
if os.path.exists(req_file):
_pip("-r", req_file)
# -- 2. torch -------------------------------------------------------------
# Serverless: always install cu128 build (clean env).
# Classic: skip if a compatible torch is already installed (e.g.
# DBR 18.1 ships torch 2.9.0+cu129 which works with vLLM).
torch_installed = _get_installed_version("torch")
torch_ok = not is_serverless and _version_satisfies(torch_installed, _TORCH_VERSION)
if torch_ok:
logger.info(
f"torch {torch_installed} already satisfies {_TORCH_VERSION}, skipping"
)
else:
logger.info("Installing torch+cu128")
_pip(
"--no-cache-dir",
"torch==2.9.0+cu128",
"--index-url",
"https://download.pytorch.org/whl/cu128",
)
# -- 3. flash-attn --------------------------------------------------------
# Expensive source compilation (~5-10 min). Skip if already installed
# at the target version (e.g. DBR 18.1 ships flash-attn 2.8.3
# pre-compiled against its torch build).
flash_installed = _get_installed_version("flash-attn")
flash_ok = _version_satisfies(flash_installed, _FLASH_ATTN_VERSION)
if flash_ok:
logger.info(
f"flash-attn {flash_installed} already satisfies {_FLASH_ATTN_VERSION}, skipping"
)
else:
logger.info("Installing build tooling + flash-attn from source")
_pip(
"-U", "--no-cache-dir", "wheel==0.46.3", "ninja==1.13.0", "packaging==26.0"
)
_pip(
"--force-reinstall",
"--no-cache-dir",
"--no-build-isolation",
"--no-deps",
"flash-attn==2.8.3",
)
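        # --no-build-isolation compiles against the torch already installed in
        # step 2 (rather than a fresh isolated build env), so torch must come
        # first; --no-deps stops pip from resolving a different torch pin.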
# -- 4. opencv-python-headless (OpenSSL FIPS fix) -------------------------
opencv_installed = _get_installed_version("opencv-python-headless")
if not _version_satisfies(opencv_installed, _OPENCV_VERSION):
_pip("--no-cache-dir", "opencv-python-headless==4.12.0.88")
# -- 5. ray, hf_transfer, einops -----------------------------------------
# Ray 2.47.1 for all compute types (Databricks RayPatches incompatible
# with BlockMetadata changes in Ray >=2.48).
_ray_spec = _RAY_VERSION
_ray_pin = "ray[data]==2.47.1"
to_install = []
ray_installed = _get_installed_version("ray")
if not _version_satisfies(ray_installed, _ray_spec):
to_install.append(_ray_pin)
else:
logger.info(f"ray {ray_installed} already satisfies {_ray_spec}, skipping")
if not _version_satisfies(
_get_installed_version("hf-transfer"), _HF_TRANSFER_VERSION
):
to_install.append("hf_transfer==0.1.9")
if not _version_satisfies(_get_installed_version("einops"), _EINOPS_VERSION):
to_install.append("einops==0.8.2")
if to_install:
_pip(*to_install)
# -- 6. vLLM --------------------------------------------------------------
vllm_installed = _get_installed_version("vllm")
if not _version_satisfies(vllm_installed, _VLLM_VERSION):
_pip("vllm==0.13.0")
else:
logger.info(
f"vllm {vllm_installed} already satisfies {_VLLM_VERSION}, skipping"
)
# -- 7. Re-pin ray after vllm (vllm may pull a newer ray as transitive dep)
ray_after_vllm = _get_installed_version("ray")
if not _version_satisfies(ray_after_vllm, _RAY_VERSION):
logger.info(
f"ray drifted to {ray_after_vllm} after vllm install, "
f"re-pinning to {_ray_pin}"
)
_pip(_ray_pin)
# -- 8. Classic-only: pyarrow + cuda-python --------------------------------
# pyarrow 21 is needed for Arrow-based Spark↔Ray data exchange on
# PySpark 4.x (DBR 18.1+). cuda-python provides CUDA bindings that
# vLLM requires on classic clusters (serverless bundles them).
if not is_serverless:
pyarrow_installed = _get_installed_version("pyarrow")
if not _version_satisfies(pyarrow_installed, _PYARROW_VERSION):
_pip("--force-reinstall", "--no-deps", "pyarrow==21")
cuda_python_installed = _get_installed_version("cuda-python")
if not _version_satisfies(cuda_python_installed, _CUDA_PYTHON_VERSION):
_pip('"cuda-python>=12,<13"')
logger.info("GPU inference requirements ready")
def _validate_three_part_name(value: str, label: str) -> None:
"""Raise ValueError if *value* is not a three-part dotted name."""
parts = value.split(".")
if len(parts) != 3 or not all(parts):
raise ValueError(
f"`{label}` must be in the format `catalog.schema.name`, got: {value!r}"
)
@track_usage
def run_inference(
input_source: str | DataFrame,
model_name: str,
output_table: str,
*,
prompt_column: str = "prompt",
model_version: Optional[int | Literal["latest"]] = None,
model_alias: Optional[str] = None,
system_prompt: Optional[str] = None,
max_new_tokens: Optional[int] = None,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
top_k: Optional[int] = None,
min_p: Optional[float] = None,
presence_penalty: Optional[float] = 0.0,
repetition_penalty: Optional[float] = 1.0,
max_model_len: Optional[int] = None,
num_gpus: int = 4,
gpu_type: str = "a10",
concurrency: Optional[int] = None,
batch_size: Optional[int] = None,
tensor_parallel_size: int = 1,
gpu_memory_utilization: Optional[float] = 0.85,
max_num_seqs: Optional[int] = None,
max_num_batched_tokens: Optional[int] = None,
kv_cache_dtype: Optional[str] = None,
dtype: Optional[str] = None,
quantization: Optional[str] = None,
uc_volumes_path: Optional[str] = None,
auto_build_prompt: bool = True,
output_column: str = "model_output",
pk_column: Optional[str] = None,
table_operation: Optional[Literal["overwrite", "append"]] = "overwrite",
ray_object_store_memory: float = 0.5,
http_pool_maxsize: int = 64,
trigger_remote: bool = True,
wait: bool = False,
compute_type: Literal["serverless", "classic", "anyscale"] = "serverless",
capacity_reservation_id: Optional[str] = "auto",
availability_zone: Optional[str] = "auto",
ray_kwargs: Optional[dict] = None,
env_vars: Optional[dict[str, str]] = None,
output_schema: Optional[object] = None,
chat_template_kwargs: Optional[dict] = None,
output_storage_location: Optional[str] = None,
output_bucket: Optional[str] = None,
checkpoint_path: Optional[str] = None,
anyscale_kwargs: Optional[AnyscaleKwargs | dict] = None,
) -> RemoteInferenceRun | dict | None:
"""Run vLLM batch inference on Databricks GPU compute.
    Supports three compute backends via ``compute_type``:

    - ``"serverless"`` (default) — Databricks Serverless GPU. Uses A10 or
      H100 accelerators. Stages data through ``uc_volumes_path``.
    - ``"classic"`` — Databricks classic cluster with p5.48xlarge (8x H100)
      or p6-b200.48xlarge (8x B200) nodes. Uses Ray-on-Spark for
      distributed inference. GPUs in multiples of 8. The driver
      contributes 8 GPUs; the number of additional worker nodes is
      auto-computed as ``max(0, (num_gpus - 8) // 8)``, e.g.
      ``num_gpus=24`` gives driver (8 GPUs) + 2 workers. Supports AWS
      Capacity Blocks.
    - ``"anyscale"`` — submits the job to Anyscale. Always remote
      (``trigger_remote=False`` is not supported); configured via
      ``anyscale_kwargs``, writing through ``output_storage_location`` /
      ``output_bucket``, resumable via ``checkpoint_path``.
When ``trigger_remote=True`` (default), the function validates inputs,
serializes parameters, and submits a remote job via ``run_remote()``.
When ``trigger_remote=False``, inference runs directly on the current
cluster (must already have GPUs available).
Args:
input_source: Fully qualified UC table name (``catalog.schema.table``)
or DataFrame.
model_name: UC model name (``catalog.schema.model``).
output_table: Fully qualified output table name.
prompt_column: Column containing prompt text. Defaults to ``"prompt"``.
model_version: Specific model version (int) or ``"latest"``.
Mutually exclusive with ``model_alias``. ``None`` = production alias.
model_alias: UC alias (e.g. ``"champion"``). Mutually exclusive with
``model_version``.
system_prompt: System instruction for chat completions.
max_new_tokens: Maximum tokens to generate per prompt.
temperature: Sampling temperature (0.0 = greedy).
top_p: Nucleus sampling threshold.
        top_k: Number of top tokens to consider. Set to 0 (or -1) to
            consider all tokens.
        min_p: Minimum probability for a token to be considered,
            relative to the probability of the most likely token.
            Must be in [0, 1].
        presence_penalty: Penalizes new tokens based on whether they
            appear in the generated text so far. Values > 0 encourage
            the model to use new tokens; values < 0 encourage it to
            repeat tokens.
        repetition_penalty: Penalizes new tokens based on whether they
            appear in the prompt and the generated text so far.
            Values > 1 encourage new tokens; values < 1 encourage
            repetition.
max_model_len: Maximum sequence length.
        num_gpus: Number of GPUs. Defaults to 4. Must be a multiple of 8
            for ``compute_type="classic"`` (any GPU type) and for
            serverless H100. On classic compute this is the total across
            driver + workers: ``num_gpus=8`` runs driver-only;
            ``num_gpus=16`` runs driver + 1 worker node.
gpu_type: GPU type — ``"a10"``, ``"h100"``, or ``"b200"``. Defaults
to ``"a10"``. For ``compute_type="classic"``, must be ``"h100"``
or ``"b200"``.
concurrency: Number of vLLM data-parallel replicas.
batch_size: Rows per inference batch.
tensor_parallel_size: GPUs per vLLM instance.
gpu_memory_utilization: Fraction of GPU memory to use for vLLM.
max_num_seqs: Maximum concurrent sequences per vLLM engine.
max_num_batched_tokens: Maximum tokens per vLLM iteration step.
kv_cache_dtype: Data type for KV cache (e.g. ``"fp8"``).
dtype: Model weight data type for vLLM (e.g. ``"bfloat16"``,
``"float16"``, ``"auto"``). ``None`` = auto-resolved from model config.
        quantization: vLLM quantization method. ``None`` (the signature
            default) = don't pass quantization to vLLM (auto-detect from
            the model). ``"auto"`` = resolved from model_config → GPU
            defaults (``"fp8"`` for H100/B200, ``None`` for A10).
            ``"fp8"`` = dynamically quantize weights to FP8_E4M3 at
            load time.
        uc_volumes_path: UC Volume path for staging data. Defaults to
            ``DEFAULT_UC_VOLUMES_PATH`` when ``None``.
auto_build_prompt: Auto-build prompt column from model template.
output_column: Name of the output column containing model predictions.
Defaults to ``"model_output"``.
pk_column: Name of an existing column to use as the row identifier for
joining inference results back to the input. If ``None`` (default),
a synthetic row ID is generated automatically. Use this when your
input already has a unique primary key column to avoid creating a
redundant one.
table_operation: How to write the output table — ``"overwrite"`` (default)
replaces the table, ``"append"`` adds rows to an existing table.
ray_object_store_memory: Fraction of cluster memory for Ray object
store. Defaults to 0.5.
http_pool_maxsize: Max concurrent HTTP connections for model download.
trigger_remote: If ``True``, submit via ``run_remote()``. If ``False``,
run directly on current cluster.
wait: If ``True`` and ``trigger_remote=True``, block until completion
and return the result dict.
        compute_type: ``"serverless"`` (default), ``"classic"``, or
            ``"anyscale"``.
capacity_reservation_id: AWS Capacity Block reservation ID
(e.g. ``"cr-0abc1234def56789a"``). Classic only. Configures the
cluster with on-demand instances and the
``X-Databricks-AwsCapacityBlockId`` tag. Pass ``"auto"`` to
fetch from the ``WORKSPACE_CONFIGURATION`` secret scope
(keys ``GPU_CAPACITY_RESERVATION_ID`` and
``GPU_AVAILABILITY_ZONE``).
availability_zone: AWS availability zone for the Capacity Block
(e.g. ``"us-east-1a"``). Required when using
``capacity_reservation_id`` (auto-resolved when ``"auto"``).
ray_kwargs: Dict of Ray-on-Spark tuning parameters forwarded to
``run_classic_inference`` (classic only). Supported keys:
``num_cpus_per_node``, ``num_cpus_head_node``,
``num_gpus_head_node``.
env_vars: Optional dictionary of environment variables to set in
each Ray worker process before inference starts. Ray actors do
not inherit the driver's environment, so diagnostic variables
like ``VLLM_LOGGING_LEVEL=DEBUG`` must be forwarded explicitly.
Defaults to ``None`` (no extra variables).
output_schema: Optional PySpark ``DataType`` describing the JSON
shape the model is expected to produce (e.g.
``ArrayType(StructType([...]))``). When set, the output table
gets a parsed ``model_output`` struct column alongside the raw
text in ``model_output_raw``. Accepts a ``DataType`` instance,
a DDL string (e.g. ``"array<struct<l1_category:string,...>>"``),
or the dict form produced by
:func:`ml_toolkit.ops.helpers.mlflow.serialize_pyspark_schema`.
Required for HuggingFace model IDs (no MLflow signature); for
UC models, overrides whatever the MLflow signature declared.
chat_template_kwargs: Extra keyword arguments forwarded to the
tokenizer's ``apply_chat_template`` call when prompts are
rendered. Use this to toggle model-family-specific flags such
as ``{"enable_thinking": False}`` for Qwen3 (suppresses
``<think>`` blocks) or ``{"thinking": True}`` for
DeepSeek-V3.1 / Granite 3.2. The toolkit does not interpret
the dict — it splats the keys into the call, so callers are
responsible for using the kwarg name their model expects.
            Defaults to ``None`` (no extra kwargs). Applied on both
            serverless and classic compute paths.
Returns:
- ``RemoteInferenceRun`` when ``trigger_remote=True`` and ``wait=False``.
- Result ``dict`` when ``trigger_remote=True`` and ``wait=True``.
- ``None`` when ``trigger_remote=False``.
Raises:
ValueError: If validation fails (table format, GPU config, etc.).
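
    Example (structured output; the DDL schema below is illustrative)::

        remote_run = run_inference(
            "catalog.schema.input_table",
            "catalog.schema.my_model",
            "catalog.schema.output",
            output_schema="array<struct<l1_category:string,confidence:double>>",
            chat_template_kwargs={"enable_thinking": False},  # e.g. Qwen3
        )
        result = remote_run.get_result()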
"""
logger = get_logger()
from yipit_databricks_client import get_dbutils
dbutils = get_dbutils()
is_classic = compute_type == "classic"
is_anyscale = compute_type == "anyscale"
is_serverless = compute_type == "serverless"
# ------------------------------------------------------------------
# Validation
# ------------------------------------------------------------------
from ml_toolkit.ml.llm.inference._model_utils import _is_hf_model_id
_is_hf = _is_hf_model_id(model_name)
if not _is_hf:
_validate_three_part_name(model_name, "model_name")
_validate_three_part_name(output_table, "output_table")
if isinstance(input_source, str):
assert_table_exists(input_source)
elif isinstance(input_source, DataFrame) and trigger_remote:
raise ValueError("Input source must be a table name when trigger_remote=True")
if _is_hf and (model_version is not None or model_alias is not None):
raise ValueError(
"model_version and model_alias are not supported for "
f"HuggingFace model IDs, got model_name={model_name!r}"
)
if model_version is not None and model_alias is not None:
raise ValueError(
"Cannot specify both `model_version` and `model_alias`. "
f"Got model_version={model_version!r} and model_alias={model_alias!r}."
)
if model_version is not None and model_version != "latest":
if not isinstance(model_version, int):
raise ValueError(
f"`model_version` must be an int or 'latest', got: {model_version!r}"
)
if model_alias is not None:
if not isinstance(model_alias, str) or not model_alias.strip():
raise ValueError(
f"`model_alias` must be a non-empty string, got: {model_alias!r}"
)
if not _is_hf and model_alias is not None:
assert_model_alias_exists(model_name, model_alias)
if not _is_hf and model_version is not None and model_version != "latest":
assert_model_version_exists(model_name, model_version)
gpu_type = gpu_type.lower()
if is_classic:
if gpu_type not in ("h100", "b200"):
raise ValueError(
f"Classic compute only supports H100/B200 GPUs, got gpu_type={gpu_type!r}"
)
if num_gpus % 8 != 0:
raise ValueError(
f"num_gpus must be a multiple of 8 for classic compute, got {num_gpus}"
)
if capacity_reservation_id == "auto":
capacity_reservation_id = dbutils.secrets.get(
"WORKSPACE_CONFIGURATION", "GPU_CAPACITY_RESERVATION_ID"
)
availability_zone = dbutils.secrets.get(
"WORKSPACE_CONFIGURATION", "GPU_AVAILABILITY_ZONE"
)
logger.info(
f"Resolved capacity block from secrets: "
f"reservation={capacity_reservation_id}, az={availability_zone}"
)
if capacity_reservation_id is not None and availability_zone is None:
raise ValueError(
"`availability_zone` is required when using "
"`capacity_reservation_id`. Specify the AZ assigned to your "
"Capacity Block (e.g. 'us-east-1a')."
)
else:
if gpu_type not in ("a10", "h100"):
raise ValueError(f"gpu_type must be 'a10' or 'h100', got: {gpu_type!r}")
if gpu_type == "h100" and num_gpus % 8 != 0:
raise ValueError(
f"num_gpus must be a multiple of 8 for H100, got {num_gpus}"
)
if uc_volumes_path is None:
uc_volumes_path = DEFAULT_UC_VOLUMES_PATH
logger.info(f"Using uc_volumes_path: {uc_volumes_path}")
# Normalize output_schema to a JSON-serializable dict so it survives
# kwargs forwarding across the remote-submit boundary. The core
# inference functions deserialize back to a DataType.
if output_schema is not None and not isinstance(output_schema, dict):
from ml_toolkit.ops.helpers.mlflow import serialize_pyspark_schema
output_schema = serialize_pyspark_schema(output_schema, "output_schema")
# ------------------------------------------------------------------
# Collect optional params (shared across both compute types)
# ------------------------------------------------------------------
optional_params: dict = {
"prompt_column": prompt_column,
"model_version": model_version,
"model_alias": model_alias,
"system_prompt": system_prompt,
"max_new_tokens": max_new_tokens,
"temperature": temperature,
"top_p": top_p,
"top_k": top_k,
"min_p": min_p,
"presence_penalty": presence_penalty,
"repetition_penalty": repetition_penalty,
"max_model_len": max_model_len,
"concurrency": concurrency,
"num_gpus": num_gpus,
"gpu_type": gpu_type,
"batch_size": batch_size,
"tensor_parallel_size": tensor_parallel_size,
"gpu_memory_utilization": gpu_memory_utilization,
"max_num_seqs": max_num_seqs,
"max_num_batched_tokens": max_num_batched_tokens,
"kv_cache_dtype": kv_cache_dtype,
"dtype": dtype,
"quantization": quantization,
"uc_volumes_path": uc_volumes_path,
"auto_build_prompt": auto_build_prompt,
"output_column": output_column,
"pk_column": pk_column,
"table_operation": table_operation,
"ray_object_store_memory": ray_object_store_memory,
"http_pool_maxsize": http_pool_maxsize,
"env_vars": env_vars,
"output_schema": output_schema,
"chat_template_kwargs": chat_template_kwargs,
}
if is_classic:
optional_params["ray_kwargs"] = ray_kwargs
elif is_anyscale:
optional_params["output_storage_location"] = output_storage_location
optional_params["output_bucket"] = output_bucket
if checkpoint_path is None:
from ml_toolkit.ml.llm.inference.anyscale._delta_io import (
DEFAULT_OUTPUT_BUCKET,
generate_checkpoint_path,
)
_bucket = output_bucket or DEFAULT_OUTPUT_BUCKET
_catalog, _schema, _table = output_table.split(".")
checkpoint_path = generate_checkpoint_path(_bucket, _schema, _table)
logger.info(
f"Ray Data checkpoint path: {checkpoint_path} "
f"(pass back via checkpoint_path=... to resume)"
)
optional_params["checkpoint_path"] = checkpoint_path
else:
optional_params["uc_volumes_path"] = uc_volumes_path
# ------------------------------------------------------------------
# Dispatch
# ------------------------------------------------------------------
if is_anyscale:
if not trigger_remote:
raise ValueError(
"compute_type='anyscale' always submits a remote job; "
"trigger_remote=False is not supported."
)
from ml_toolkit.functions.llm.inference.anyscale_function import (
submit_anyscale_inference,
)
if anyscale_kwargs is None:
anyscale_kwargs = AnyscaleKwargs()
elif isinstance(anyscale_kwargs, dict):
anyscale_kwargs = AnyscaleKwargs(**anyscale_kwargs)
logger.info(
f"Submitting Anyscale inference: model={model_name}, "
f"input={input_source}, output={output_table}"
)
remote_run = submit_anyscale_inference(
input_source,
model_name,
output_table,
compute_config=anyscale_kwargs.compute_config,
image_uri=anyscale_kwargs.image_uri,
job_kind=anyscale_kwargs.job_kind,
team=anyscale_kwargs.team,
py_modules=anyscale_kwargs.py_modules,
working_dir=anyscale_kwargs.working_dir,
excludes=anyscale_kwargs.excludes,
requirements=anyscale_kwargs.requirements,
reasoning_parser=anyscale_kwargs.reasoning_parser,
max_actor_restarts=anyscale_kwargs.max_actor_restarts,
max_task_retries=anyscale_kwargs.max_task_retries,
**optional_params,
)
return remote_run.get_result() if wait else remote_run
if is_classic and trigger_remote:
from ml_toolkit.functions.llm.inference.classic_function import (
submit_classic_inference,
)
# The driver is a p5.48xlarge / p6-b200.48xlarge with 8 GPUs; it
# runs a vLLM replica alongside coordination. ``num_gpus`` is the
# total across driver + workers, so workers cover whatever is
# left after the driver's 8.
num_workers = max(0, (num_gpus - 8) // 8)
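        # e.g. num_gpus=8 → 0 workers (driver-only); num_gpus=16 → 1 worker;
        # num_gpus=24 → 2 workers.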
logger.info(
f"Submitting remote classic inference: model={model_name}, "
f"input={input_source}, output={output_table}, "
f"workers={num_workers}"
)
# gpu_type is passed explicitly for cluster config selection;
# also remove it from optional_params to avoid duplicate kwarg
# (it's still forwarded to the runner via _build_runner_kwargs).
optional_params.pop("gpu_type", None)
remote_run = submit_classic_inference(
input_source,
model_name,
output_table,
gpu_type=gpu_type,
num_workers=num_workers,
capacity_reservation_id=capacity_reservation_id,
availability_zone=availability_zone,
**optional_params,
)
return remote_run.get_result() if wait else remote_run
elif is_serverless and trigger_remote:
from ml_toolkit.functions.llm.inference.serverless_function import (
submit_serverless_inference,
)
logger.info(
f"Submitting remote serverless inference: model={model_name}, "
f"input={input_source}, output={output_table}"
)
# gpu_type/num_gpus are explicit kwargs on submit_serverless_inference;
# drop them from optional_params to avoid duplicate kwargs.
optional_params.pop("gpu_type", None)
optional_params.pop("num_gpus", None)
remote_run = submit_serverless_inference(
input_source,
model_name,
output_table,
num_gpus=num_gpus,
gpu_type=gpu_type,
**optional_params,
)
return remote_run.get_result() if wait else remote_run
elif is_classic and not trigger_remote:
logger.info(
f"Running local classic inference: model={model_name}, "
f"input={input_source}, output={output_table}"
)
_run_classic_inference_local(
input_source,
model_name,
output_table,
**optional_params,
)
return None
elif is_serverless and not trigger_remote:
logger.info(
f"Running local inference: model={model_name}, "
f"input={input_source}, output={output_table}"
)
_run_serverless_inference_local(
input_source,
model_name,
output_table,
**optional_params,
)
return None