Source code for ml_toolkit.functions.llm.inference.function

"""Shared scaffolding for the functions/llm/inference module.

Backend-specific submission logic lives in ``serverless_function.py``,
``classic_function.py``, and ``anyscale_function.py``. This module hosts:

- The optional-key allowlists each backend forwards to its in-cluster runner
- ``_build_runner_kwargs`` — the single helper that turns ``run_inference``'s
  ``optional_params`` dict into a JSON-safe kwargs dict for the runner
- ``RemoteInferenceRun`` — the handle returned for serverless/classic submissions
- ``_run_*_from_dict`` — the in-cluster entrypoints invoked by ``run_remote``
- ``_run_*_local`` — the in-process paths used when ``trigger_remote=False``
"""

from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any, Iterable

# Path to GPU requirements file, relative to REPO_BASE_PATH.
# The backend notebook does `%pip install -r <REPO_BASE_PATH>/<this path>`
# after installing torch+cu128 and flash-attn from source.
GPU_REQUIREMENTS_PATH = "functions/llm/inference/gpu_requirements.txt"
ANYSCALE_REQUIREMENTS_PATH = "ml_toolkit/dependencies/databricks_serverless_gpus_4.txt"

# Custom notebook that handles ordered GPU package installation
# (torch → flash-attn from source → opencv-headless → remaining requirements).
# Path relative to the git repo root, without .py extension.
GPU_NOTEBOOK_PATH = "ml_toolkit/functions/llm/inference/backend"
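
# Sketch of the install order the backend notebook performs (assumed from the
# comments above; the exact commands and version pins live in the notebook itself):
#
#     %pip install torch --index-url https://download.pytorch.org/whl/cu128
#     %pip install flash-attn --no-build-isolation
#     %pip install opencv-python-headless
#     %pip install -r <REPO_BASE_PATH>/functions/llm/inference/gpu_requirements.txt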

# Parameters that are always forwarded to the in-cluster run_*_inference runners
_REQUIRED_KEYS = ("input_source", "model_name", "output_table")

# Optional parameters — only included in the remote kwargs when not None
_OPTIONAL_KEYS = (
    "prompt_column",
    "model_version",
    "model_alias",
    "system_prompt",
    "max_new_tokens",
    "temperature",
    "top_p",
    "top_k",
    "min_p",
    "presence_penalty",
    "repetition_penalty",
    "max_model_len",
    "num_gpus",
    "gpu_type",
    "concurrency",
    "batch_size",
    "tensor_parallel_size",
    "gpu_memory_utilization",
    "max_num_seqs",
    "max_num_batched_tokens",
    "kv_cache_dtype",
    "dtype",
    "quantization",
    "uc_volumes_path",
    "auto_build_prompt",
    "output_column",
    "pk_column",
    "table_operation",
    "ray_object_store_memory",
    "http_pool_maxsize",
    "env_vars",
    "output_schema",
    "chat_template_kwargs",
)

# Classic-specific keys forwarded to run_classic_inference
_CLASSIC_OPTIONAL_KEYS = _OPTIONAL_KEYS + ("ray_kwargs",)

# Anyscale-specific keys forwarded to run_anyscale_inference (in-cluster).
# Reuses the shared optional set + Anyscale-only fields.
_ANYSCALE_OPTIONAL_KEYS = _OPTIONAL_KEYS + (
    "output_storage_location",
    "output_bucket",
    "checkpoint_path",
    # Forwarded by caller-side ``resolve_model`` so the runner can auto-build
    # prompts without re-resolving the UC model.
    "stored_prompts",
    # vLLM engine-level reasoning parser. Routes chain-of-thought to
    # ``output.reasoning_content`` so ``output.text`` is the final answer
    # only. Auto-resolved per model family caller-side.
    "reasoning_parser",
    # Ray ActorPoolStrategy resilience knobs — see vllm_anyscale._run_ray_pipeline.
    "max_actor_restarts",
    "max_task_retries",
)


# ---------------------------------------------------------------------------
# Shared kwargs builder
# ---------------------------------------------------------------------------


def _build_runner_kwargs(
    input_source: str,
    model_name: str,
    output_table: str,
    allowed_keys: Iterable[str],
    **optional_params: Any,
) -> dict[str, Any]:
    """Build a JSON-safe kwargs dict for an in-cluster ``run_*_inference`` call.

    Always includes the three required positional args. Optional params are
    only included when their value is not ``None`` so omitted params keep
    their default (or ``_UNSET`` sentinel) inside the runner.
    """
    kwargs: dict[str, Any] = {
        "input_source": input_source,
        "model_name": model_name,
        "output_table": output_table,
    }
    for key in allowed_keys:
        value = optional_params.get(key)
        if value is not None:
            kwargs[key] = value
    return kwargs
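

# Illustrative use of the builder (a sketch; the actual call sites live in the
# backend-specific submission modules). Table and model names below are
# placeholders. Values passed as ``None`` are dropped, so the runner keeps its
# own defaults / ``_UNSET`` sentinels:
#
#     _build_runner_kwargs(
#         input_source="catalog.schema.prompts",
#         model_name="my-model",
#         output_table="catalog.schema.outputs",
#         allowed_keys=_OPTIONAL_KEYS,
#         temperature=0.2,
#         top_p=None,  # omitted from the result
#     )
#     # -> {"input_source": "catalog.schema.prompts", "model_name": "my-model",
#     #     "output_table": "catalog.schema.outputs", "temperature": 0.2}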


# ---------------------------------------------------------------------------
# In-cluster entrypoints (invoked by run_remote on the remote cluster)
# ---------------------------------------------------------------------------


def _run_inference_from_dict(**kwargs: Any) -> dict[str, Any]:
    """Entrypoint for remote serverless execution via ``main_task_single.py``.

    Receives JSON-safe keyword arguments, forwards them to
    ``run_serverless_inference``, and returns a status dict suitable for
    ``dbutils.notebook.exit()``.
    """
    from ml_toolkit.ml.llm.inference.serverless.vllm_serverless import (
        run_serverless_inference,
    )

    output_table = kwargs["output_table"]
    run_serverless_inference(**kwargs)

    return {"status": "success", "output_table": output_table}


def _run_classic_inference_from_dict(**kwargs: Any) -> dict[str, Any]:
    """Entrypoint for remote classic execution via ``main_task_single.py``.

    Receives JSON-safe keyword arguments, forwards them to
    ``run_classic_inference``, and returns a status dict.
    """
    from ml_toolkit.ml.llm.inference.classic.vllm_classic import (
        run_classic_inference as _run_classic_inference_core,
    )

    output_table = kwargs["output_table"]
    _run_classic_inference_core(**kwargs)

    return {"status": "success", "output_table": output_table}


# ---------------------------------------------------------------------------
# RemoteInferenceRun handle
# ---------------------------------------------------------------------------


@dataclass
class RemoteInferenceRun:
    """Handle for a submitted remote inference job (serverless or classic)."""

    job_run_id: int
    databricks_url: str
    model_name: str
    output_table: str
    inference_config: dict[str, Any] = field(default_factory=dict)
    task_type: str = "gpu_inference"

    def get_result(self, polling_interval: int = 60) -> dict[str, Any]:
        """Block until the job completes and return the result dict."""
        from ml_toolkit.ops.orchestration.run_remote import RemoteJobRun

        inner = RemoteJobRun(
            job_run_id=self.job_run_id,
            databricks_url=self.databricks_url,
            task_type=self.task_type,
        )
        result = inner.get_result(polling_interval=polling_interval)
        result["model_name"] = self.model_name
        result["output_table"] = self.output_table
        return result

    def is_complete(self) -> bool:
        """Non-blocking check whether the job has finished."""
        from ml_toolkit.ops.orchestration.run_remote import RemoteJobRun

        inner = RemoteJobRun(
            job_run_id=self.job_run_id,
            databricks_url=self.databricks_url,
            task_type=self.task_type,
        )
        return inner.is_complete()

    @property
    def status(self) -> str:
        """Current lifecycle/result state of the job."""
        from ml_toolkit.ops.orchestration.run_remote import RemoteJobRun

        inner = RemoteJobRun(
            job_run_id=self.job_run_id,
            databricks_url=self.databricks_url,
            task_type=self.task_type,
        )
        return inner.status
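

# Illustrative caller-side usage of the handle (a sketch; actual construction
# happens in the serverless/classic submission paths, and the field values
# below are placeholders):
#
#     run = RemoteInferenceRun(
#         job_run_id=12345,
#         databricks_url="https://example.cloud.databricks.com",
#         model_name="my-model",
#         output_table="catalog.schema.outputs",
#     )
#     if not run.is_complete():
#         print(run.status)  # non-blocking state check
#     result = run.get_result(polling_interval=30)  # blocks until completion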