"""Shared scaffolding for the functions/llm/inference module.
Backend-specific submission logic lives in ``serverless_function.py``,
``classic_function.py``, and ``anyscale_function.py``. This module hosts:
- The optional-key allowlists each backend forwards to its in-cluster runner
- ``_build_runner_kwargs`` — the single helper that turns ``run_inference``'s
``optional_params`` dict into a JSON-safe kwargs dict for the runner
- ``RemoteInferenceRun`` — the handle returned for serverless/classic submissions
- ``_run_*_from_dict`` — the in-cluster entrypoints invoked by ``run_remote``
- ``_run_*_local`` — the in-process paths used when ``trigger_remote=False``
"""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any, Iterable
# Path to GPU requirements file, relative to REPO_BASE_PATH.
# The backend notebook does `%pip install -r <REPO_BASE_PATH>/<this path>`
# after installing torch+cu128 and flash-attn from source.
GPU_REQUIREMENTS_PATH = "functions/llm/inference/gpu_requirements.txt"
ANYSCALE_REQUIREMENTS_PATH = "ml_toolkit/dependencies/databricks_serverless_gpus_4.txt"
# Custom notebook that handles ordered GPU package installation
# (torch → flash-attn from source → opencv-headless → remaining requirements).
# Path relative to the git repo root, without .py extension.
GPU_NOTEBOOK_PATH = "ml_toolkit/functions/llm/inference/backend"
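# Rough shape of that install sequence (illustrative sketch only; the exact
# pins, index URLs, and flags live in the backend notebook itself):
#
#     %pip install torch --index-url https://download.pytorch.org/whl/cu128
#     %pip install flash-attn --no-build-isolation   # built from source
#     %pip install opencv-python-headless
#     %pip install -r {REPO_BASE_PATH}/functions/llm/inference/gpu_requirements.txt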
# Parameters that are always forwarded to run_serverless_inference
_REQUIRED_KEYS = ("input_source", "model_name", "output_table")
# Optional parameters — only included in the remote kwargs when not None
_OPTIONAL_KEYS = (
"prompt_column",
"model_version",
"model_alias",
"system_prompt",
"max_new_tokens",
"temperature",
"top_p",
"top_k",
"min_p",
"presence_penalty",
"repetition_penalty",
"max_model_len",
"num_gpus",
"gpu_type",
"concurrency",
"batch_size",
"tensor_parallel_size",
"gpu_memory_utilization",
"max_num_seqs",
"max_num_batched_tokens",
"kv_cache_dtype",
"dtype",
"quantization",
"uc_volumes_path",
"auto_build_prompt",
"output_column",
"pk_column",
"table_operation",
"ray_object_store_memory",
"http_pool_maxsize",
"env_vars",
"output_schema",
"chat_template_kwargs",
)
# Classic-specific keys forwarded to run_classic_inference
_CLASSIC_OPTIONAL_KEYS = _OPTIONAL_KEYS + ("ray_kwargs",)
# Anyscale-specific keys forwarded to run_anyscale_inference (in-cluster).
# Reuses the shared optional set + Anyscale-only fields.
_ANYSCALE_OPTIONAL_KEYS = _OPTIONAL_KEYS + (
"output_storage_location",
"output_bucket",
"checkpoint_path",
# Forwarded by caller-side ``resolve_model`` so the runner can auto-build
# prompts without re-resolving the UC model.
"stored_prompts",
# vLLM engine-level reasoning parser. Routes chain-of-thought to
# ``output.reasoning_content`` so ``output.text`` is the final answer
# only. Auto-resolved per model family caller-side.
"reasoning_parser",
# Ray ActorPoolStrategy resilience knobs — see vllm_anyscale._run_ray_pipeline.
"max_actor_restarts",
"max_task_retries",
)
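# Example of how the allowlists diverge (hypothetical values): given
# optional_params = {"temperature": 0.2, "output_bucket": "s3://bkt", "top_p": None},
# the serverless allowlist forwards only "temperature" ("top_p" is None and
# "output_bucket" is Anyscale-only), while _ANYSCALE_OPTIONAL_KEYS forwards
# both "temperature" and "output_bucket". See _build_runner_kwargs below.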
# ---------------------------------------------------------------------------
# Shared kwargs builder
# ---------------------------------------------------------------------------
def _build_runner_kwargs(
    input_source: str,
    model_name: str,
    output_table: str,
    allowed_keys: Iterable[str],
    **optional_params: Any,
) -> dict[str, Any]:
    """Build a JSON-safe kwargs dict for an in-cluster ``run_*_inference`` call.

    Always includes the three required positional args. Optional params are
    only included when their value is not ``None`` so omitted params keep
    their default (or ``_UNSET`` sentinel) inside the runner.
    """
    kwargs: dict[str, Any] = {
        "input_source": input_source,
        "model_name": model_name,
        "output_table": output_table,
    }
    for key in allowed_keys:
        value = optional_params.get(key)
        if value is not None:
            kwargs[key] = value
    return kwargs
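# Doctest-style sketch of the filtering above (hypothetical table/model names):
#
#     >>> _build_runner_kwargs(
#     ...     input_source="main.demo.prompts",
#     ...     model_name="demo_model",
#     ...     output_table="main.demo.results",
#     ...     allowed_keys=("temperature", "top_p"),
#     ...     temperature=0.2,
#     ...     top_p=None,
#     ... )
#     {'input_source': 'main.demo.prompts', 'model_name': 'demo_model',
#      'output_table': 'main.demo.results', 'temperature': 0.2}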
# ---------------------------------------------------------------------------
# In-cluster entrypoints (invoked by run_remote on the remote cluster)
# ---------------------------------------------------------------------------
def _run_inference_from_dict(**kwargs: Any) -> dict[str, Any]:
    """Entrypoint for remote serverless execution via ``main_task_single.py``.

    Receives JSON-safe keyword arguments, forwards them to
    ``run_serverless_inference``, and returns a status dict suitable for
    ``dbutils.notebook.exit()``.
    """
    from ml_toolkit.ml.llm.inference.serverless.vllm_serverless import (
        run_serverless_inference,
    )

    output_table = kwargs["output_table"]
    run_serverless_inference(**kwargs)
    return {"status": "success", "output_table": output_table}
def _run_classic_inference_from_dict(**kwargs: Any) -> dict[str, Any]:
    """Entrypoint for remote classic execution via ``main_task_single.py``.

    Receives JSON-safe keyword arguments, forwards them to
    ``run_classic_inference``, and returns a status dict.
    """
    from ml_toolkit.ml.llm.inference.classic.vllm_classic import (
        run_classic_inference as _run_classic_inference_core,
    )

    output_table = kwargs["output_table"]
    _run_classic_inference_core(**kwargs)
    return {"status": "success", "output_table": output_table}
# ---------------------------------------------------------------------------
# RemoteInferenceRun handle
# ---------------------------------------------------------------------------
@dataclass
class RemoteInferenceRun:
    """Handle for a submitted remote inference job (serverless or classic)."""

    job_run_id: int
    databricks_url: str
    model_name: str
    output_table: str
    inference_config: dict[str, Any] = field(default_factory=dict)
    task_type: str = "gpu_inference"
    def get_result(self, polling_interval: int = 60) -> dict[str, Any]:
        """Block until the job completes and return the result dict."""
        from ml_toolkit.ops.orchestration.run_remote import RemoteJobRun

        inner = RemoteJobRun(
            job_run_id=self.job_run_id,
            databricks_url=self.databricks_url,
            task_type=self.task_type,
        )
        result = inner.get_result(polling_interval=polling_interval)
        result["model_name"] = self.model_name
        result["output_table"] = self.output_table
        return result
    def is_complete(self) -> bool:
        """Non-blocking check whether the job has finished."""
        from ml_toolkit.ops.orchestration.run_remote import RemoteJobRun

        inner = RemoteJobRun(
            job_run_id=self.job_run_id,
            databricks_url=self.databricks_url,
            task_type=self.task_type,
        )
        return inner.is_complete()
    @property
    def status(self) -> str:
        """Current lifecycle/result state of the job."""
        from ml_toolkit.ops.orchestration.run_remote import RemoteJobRun

        inner = RemoteJobRun(
            job_run_id=self.job_run_id,
            databricks_url=self.databricks_url,
            task_type=self.task_type,
        )
        return inner.status
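# Usage sketch (hypothetical handle values; real instances are constructed by
# the backend submission modules):
#
#     run = RemoteInferenceRun(
#         job_run_id=123456,
#         databricks_url="https://example.cloud.databricks.com",
#         model_name="demo_model",
#         output_table="main.demo.results",
#     )
#     if not run.is_complete():
#         print(run.status)                         # non-blocking peek
#     result = run.get_result(polling_interval=30)  # blocks until done
#     result["model_name"], result["output_table"]  # injected by the handle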