LLM End-to-End Workflow#

This guide walks through the complete lifecycle of a self-hosted LLM, from fine-tuning and registration to GPU inference at scale. The goal is a clean handoff between Data Science (model preparation) and Data Engineering (production inference).

Note

Steps 1–2 are Data Science (DS) responsibilities: fine-tune the model and verify registration.

Step 3 is a Data Engineering (DE) responsibility: run inference at scale on Databricks Serverless GPUs.

Step 1 – Fine-Tune the Model#

run_fine_tuning handles the full pipeline: loading data from a Unity Catalog table, running QLoRA training with Optuna HPO, merging the adapter into the base model, computing inference stats, and registering the result to Unity Catalog.

Minimal fine-tuning call#
from ml_toolkit.functions.llm import run_fine_tuning

REQUIREMENTS = [
    "transformers==4.57.6", "datasets==4.5.0", "accelerate==1.12.0",
    "peft==0.18.1", "bitsandbytes==0.49.1", "safetensors==0.7.0",
    "threadpoolctl==3.6.0", "optuna==4.7.0",
]

result = run_fine_tuning(
    training_table_name="catalog.schema.training_data",
    model_name="catalog.schema.my_ner_model",
    base_model="qwen2_5_3b_instruct",
    requirements=REQUIREMENTS,
    trigger_remote=True,
)
print(f"Model version: {result['model_version']}")
print(f"Training loss: {result['training_loss']:.4f}")

What gets registered automatically:

  • Merged model weights (FP16/BF16, no quantization)

  • Tokenizer

  • Generation defaults (temperature, max_new_tokens, top_p)

  • Inference stats (model architecture + prompt token statistics)

  • Training metrics and hyperparameters
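
You can cross-check these artifacts without downloading the weights by reading the registered model's MLmodel metadata. A minimal sketch, assuming the registry is Unity Catalog and using version 1 (substitute the version returned by run_fine_tuning):

import mlflow

mlflow.set_registry_uri("databricks-uc")
info = mlflow.models.get_model_info("models:/catalog.schema.my_ner_model/1")
print(info.flavors.keys())  # flavors packaged with the model
print(info.signature)       # input/output signature
print(info.metadata)        # extra metadata stored at registration time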

Optional: attach a prompt template. If your training data uses a prompt from the MLflow Prompt Registry, pass it via DataConfig so the prompt is attached to the registered model; run_inference will then auto-build the prompt column from it.

from ml_toolkit.ml.llm.fine_tuning.config import DataConfig

result = run_fine_tuning(
    training_table_name="catalog.schema.training_data",
    model_name="catalog.schema.my_ner_model",
    base_model="qwen2_5_3b_instruct",
    requirements=REQUIREMENTS,
    data_config=DataConfig(
        prompt_registry_name="catalog.schema.my_prompt",
        prompt_alias="production",
    ),
    trigger_remote=True,
)
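
For reference, registering such a prompt is a short call against the MLflow Prompt Registry. A hedged sketch, assuming MLflow 3.x (where the prompt APIs live under mlflow.genai); the template text and the {{text}} variable are purely illustrative, and the "production" alias referenced above would still need to be pointed at the resulting version (via the UI or the prompt-alias API):

import mlflow

# Prompt name matches the DataConfig above; template and variable are illustrative.
prompt = mlflow.genai.register_prompt(
    name="catalog.schema.my_prompt",
    template="Extract the named entities from the following text:\n\n{{text}}",
)
print(f"Registered prompt version: {prompt.version}")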

Optional: custom input/output signature. Pass PySpark types or DDL strings to control how Databricks Model Serving validates payloads.

from pyspark.sql.types import ArrayType, StringType, StructField, StructType

result = run_fine_tuning(
    training_table_name="catalog.schema.training_data",
    model_name="catalog.schema.my_ner_model",
    base_model="qwen2_5_3b_instruct",
    requirements=REQUIREMENTS,
    input_schema=StructType([
        StructField("text", StringType(), nullable=False),
        StructField("context", StringType(), nullable=True),
    ]),
    output_schema=ArrayType(StringType()),
    trigger_remote=True,
)
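
The same signature can also be expressed as DDL strings instead of PySpark types, as mentioned above. A sketch, assuming the toolkit accepts standard Spark DDL for both schemas:

# DDL-string equivalent of the PySpark schema above (standard Spark DDL shown).
result = run_fine_tuning(
    training_table_name="catalog.schema.training_data",
    model_name="catalog.schema.my_ner_model",
    base_model="qwen2_5_3b_instruct",
    requirements=REQUIREMENTS,
    input_schema="text STRING NOT NULL, context STRING",
    output_schema="ARRAY<STRING>",
    trigger_remote=True,
)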

See the Fine-Tuning section in the LLM docs for the full config reference (LoRAConfig, TrainingConfig, DataConfig, base model catalog, tier system, and prompt template options).
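
As a rough orientation only, a call with explicit LoRA and training settings might look like the sketch below. The field names (r, lora_alpha, num_train_epochs, learning_rate) and the lora_config= / training_config= keyword arguments are assumptions for illustration; consult the config reference for the actual attributes and how they are wired into run_fine_tuning.

from ml_toolkit.ml.llm.fine_tuning.config import LoRAConfig, TrainingConfig

# Hypothetical fields and keyword arguments — see the Fine-Tuning config reference
# for the real API surface.
result = run_fine_tuning(
    training_table_name="catalog.schema.training_data",
    model_name="catalog.schema.my_ner_model",
    base_model="qwen2_5_3b_instruct",
    requirements=REQUIREMENTS,
    lora_config=LoRAConfig(r=16, lora_alpha=32),
    training_config=TrainingConfig(num_train_epochs=2, learning_rate=2e-4),
    trigger_remote=True,
)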

Step 2 – Verify Registration#

After fine-tuning completes, confirm the model is registered correctly.

from ml_toolkit.ops.helpers.models import get_model_info

info = get_model_info("catalog.schema.my_ner_model")
info.display()
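
If you prefer to cross-check against the MLflow registry directly, a minimal sketch using the standard MLflow client (assuming the registry is Unity Catalog):

from mlflow.tracking import MlflowClient

client = MlflowClient(registry_uri="databricks-uc")
versions = client.search_model_versions("name='catalog.schema.my_ner_model'")
latest = max(versions, key=lambda v: int(v.version))
print(f"Latest version: {latest.version} (status: {latest.status})")
print(latest.tags)  # any tags recorded at registration time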

Step 3 – Run Inference#

With the model registered in steps 1–2, inference requires only the input data, model name, output table, and hardware config. Generation settings, prompt templates, and GPU batch sizing are auto-loaded from the model’s model_config and prompt_uris.

Minimal inference call – all parameters auto-loaded#
from ml_toolkit.functions.llm import run_inference

run_inference(
    input_source="catalog.schema.input_table",
    model_name="catalog.schema.my_ner_model",
    output_table="catalog.schema.output_table",
    num_gpus=4,
    gpu_type="a10",
    uc_volumes_path="/Volumes/catalog/schema/vol",
    trigger_remote=True,
)
Override auto-loaded defaults when needed#
run_inference(
    input_source="catalog.schema.input_table",
    model_name="catalog.schema.my_ner_model",
    output_table="catalog.schema.output_table",
    num_gpus=4,
    gpu_type="a10",
    uc_volumes_path="/Volumes/catalog/schema/vol",
    trigger_remote=True,
    # Generation overrides:
    temperature=0.5,
    max_new_tokens=256,
    # GPU batch overrides:
    batch_size=512,
    max_num_seqs=2048,
    max_num_batched_tokens=131072,
    max_model_len=2100,
)

What Gets Auto-Loaded#

When a model is registered via run_fine_tuning (or register_LLM_model with model_config and prompt_uris), run_inference resolves most parameters automatically. Explicit caller values always override defaults.

Parameter                 Source
temperature               model_config stored at registration time
max_new_tokens            model_config stored at registration time
top_p                     model_config stored at registration time
max_model_len             Estimated from model architecture + prompt token stats
batch_size                Estimated from model architecture + prompt token stats
max_num_seqs              Estimated from model architecture + prompt token stats
max_num_batched_tokens    Estimated from model architecture + prompt token stats
prompt_column             Built from prompt_uris template (if attached)
dtype                     preferred_dtype / serving_dtype from registration metadata
kv_cache_dtype            model_config stored at registration time
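
Because explicit caller values always win, the dtype settings in the table can be overridden in the same way as the generation and GPU batch parameters. A sketch, assuming run_inference accepts dtype and kv_cache_dtype as keyword arguments (the values shown are illustrative, not recommendations):

run_inference(
    input_source="catalog.schema.input_table",
    model_name="catalog.schema.my_ner_model",
    output_table="catalog.schema.output_table",
    num_gpus=4,
    gpu_type="a10",
    uc_volumes_path="/Volumes/catalog/schema/vol",
    trigger_remote=True,
    # Dtype overrides (illustrative values):
    dtype="bfloat16",
    kv_cache_dtype="fp8",
)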

Inspect Results#

After inference completes, verify the output and token usage.

result_df = spark.table("catalog.schema.output_table")
print(f"Rows: {result_df.count()}")
display(result_df.limit(10))
Aggregate token usage statistics#
from pyspark.sql import functions as F

token_stats = spark.table("catalog.schema.output_table").agg(
    F.count("*").alias("total_rows"),
    F.sum("prompt_tokens").alias("total_prompt_tokens"),
    F.sum("completion_tokens").alias("total_completion_tokens"),
    F.avg("prompt_tokens").alias("avg_prompt_tokens"),
    F.avg("completion_tokens").alias("avg_completion_tokens"),
)
display(token_stats)
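
As an additional sanity check, confirm that every input row produced an output row. A sketch, assuming both tables share a join key; the "id" column name is illustrative — use your tables' actual key:

input_df = spark.table("catalog.schema.input_table")
output_df = spark.table("catalog.schema.output_table")

# Rows present in the input but missing from the output ("id" is a placeholder key).
missing = input_df.join(output_df, on="id", how="left_anti")
print(f"Input rows without an output row: {missing.count()}")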