LLM End-to-End Workflow
=======================

This guide walks through the complete lifecycle of a self-hosted LLM, from fine-tuning to GPU inference. The goal is a clean handoff between **Data Science** (model preparation) and **Data Engineering** (production inference).

.. note::

   **Steps 1--2** are **Data Science (DS)** responsibilities: fine-tune the model and verify registration. **Step 3** is a **Data Engineering (DE)** responsibility: run inference at scale on Databricks Serverless GPUs.

Step 1 -- Fine-Tune the Model
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

``run_fine_tuning`` handles the full pipeline: it loads data from a Unity Catalog table, runs QLoRA training with Optuna hyperparameter optimization (HPO), merges the trained adapter into the base model, computes inference stats, and registers the result to Unity Catalog.

.. code-block:: python
   :caption: Minimal fine-tuning call

   from ml_toolkit.functions.llm import run_fine_tuning

   # Pinned dependencies for the remote training environment.
   REQUIREMENTS = [
       "transformers==4.57.6",
       "datasets==4.5.0",
       "accelerate==1.12.0",
       "peft==0.18.1",
       "bitsandbytes==0.49.1",
       "safetensors==0.7.0",
       "threadpoolctl==3.6.0",
       "optuna==4.7.0",
   ]

   result = run_fine_tuning(
       training_table_name="catalog.schema.training_data",
       model_name="catalog.schema.my_ner_model",
       base_model="qwen2_5_3b_instruct",
       requirements=REQUIREMENTS,
       trigger_remote=True,
   )

   print(f"Model version: {result['model_version']}")
   print(f"Training loss: {result['training_loss']:.4f}")

**What gets registered automatically:**

- Merged model weights (FP16/BF16, no quantization)
- Tokenizer
- Generation defaults (``temperature``, ``max_new_tokens``, ``top_p``)
- Inference stats (model architecture + prompt token statistics)
- Training metrics and hyperparameters

**Optional: attach a prompt template.** If your training data uses a prompt from the MLflow Prompt Registry, pass it via ``DataConfig`` so that it is attached to the registered model. ``run_inference`` will auto-build the prompt column from it.

.. code-block:: python

   from ml_toolkit.ml.llm.fine_tuning.config import DataConfig

   result = run_fine_tuning(
       training_table_name="catalog.schema.training_data",
       model_name="catalog.schema.my_ner_model",
       base_model="qwen2_5_3b_instruct",
       requirements=REQUIREMENTS,
       data_config=DataConfig(
           prompt_registry_name="catalog.schema.my_prompt",
           prompt_alias="production",
       ),
       trigger_remote=True,
   )

**Optional: custom input/output signature.** Pass PySpark types or DDL strings to control how Databricks Model Serving validates payloads.

.. code-block:: python

   from pyspark.sql.types import ArrayType, StringType, StructField, StructType

   result = run_fine_tuning(
       training_table_name="catalog.schema.training_data",
       model_name="catalog.schema.my_ner_model",
       base_model="qwen2_5_3b_instruct",
       requirements=REQUIREMENTS,
       input_schema=StructType([
           StructField("text", StringType(), nullable=False),
           StructField("context", StringType(), nullable=True),
       ]),
       output_schema=ArrayType(StringType()),
       trigger_remote=True,
   )

See the :ref:`Fine-Tuning ` section in the LLM docs for the full config reference (``LoRAConfig``, ``TrainingConfig``, ``DataConfig``, base model catalog, tier system, and prompt template options).

Step 2 -- Verify Registration
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

After fine-tuning completes, confirm the model is registered correctly.

.. code-block:: python

   from ml_toolkit.ops.helpers.models import get_model_info

   info = get_model_info("catalog.schema.my_ner_model")
   info.display()

Step 3 -- Run Inference
^^^^^^^^^^^^^^^^^^^^^^^^

With the model registered in Steps 1--2, inference requires only the input data, model name, output table, and hardware config. Generation settings, prompt templates, and GPU batch sizing are auto-loaded from the model's ``model_config`` and ``prompt_uris``.

.. code-block:: python
   :caption: Minimal inference call -- all parameters auto-loaded

   from ml_toolkit.functions.llm import run_inference

   run_inference(
       input_source="catalog.schema.input_table",
       model_name="catalog.schema.my_ner_model",
       output_table="catalog.schema.output_table",
       num_gpus=4,
       gpu_type="a10",
       uc_volumes_path="/Volumes/catalog/schema/vol",
       trigger_remote=True,
   )

.. code-block:: python
   :caption: Override auto-loaded defaults when needed

   run_inference(
       input_source="catalog.schema.input_table",
       model_name="catalog.schema.my_ner_model",
       output_table="catalog.schema.output_table",
       num_gpus=4,
       gpu_type="a10",
       uc_volumes_path="/Volumes/catalog/schema/vol",
       trigger_remote=True,
       # Generation overrides:
       temperature=0.5,
       max_new_tokens=256,
       # GPU batch overrides:
       batch_size=512,
       max_num_seqs=2048,
       max_num_batched_tokens=131072,
       max_model_len=2100,
   )

What Gets Auto-Loaded
^^^^^^^^^^^^^^^^^^^^^^

When a model is registered via ``run_fine_tuning`` (or ``register_LLM_model`` with ``model_config`` and ``prompt_uris``), ``run_inference`` resolves most parameters automatically. Explicit caller values **always override** these defaults.

.. list-table::
   :header-rows: 1
   :widths: 30 70

   * - Parameter
     - Source
   * - ``temperature``
     - ``model_config`` stored at registration time
   * - ``max_new_tokens``
     - ``model_config`` stored at registration time
   * - ``top_p``
     - ``model_config`` stored at registration time
   * - ``max_model_len``
     - Estimated from model architecture + prompt token stats
   * - ``batch_size``
     - Estimated from model architecture + prompt token stats
   * - ``max_num_seqs``
     - Estimated from model architecture + prompt token stats
   * - ``max_num_batched_tokens``
     - Estimated from model architecture + prompt token stats
   * - ``prompt_column``
     - Built from ``prompt_uris`` template (if attached)
   * - ``dtype``
     - ``preferred_dtype`` / ``serving_dtype`` from registration metadata
   * - ``kv_cache_dtype``
     - ``model_config`` stored at registration time
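The precedence rule is easiest to see as a parameter merge: registration-time defaults fill in whatever the caller leaves unset. The sketch below is illustrative only -- ``resolve_generation_params`` is a hypothetical helper, not part of ``ml_toolkit``.

.. code-block:: python
   :caption: Illustrative sketch of the override precedence (hypothetical helper)

   # Hypothetical helper -- not ml_toolkit internals; it only models the
   # documented rule that explicit caller values always win.
   def resolve_generation_params(caller_kwargs: dict, model_config: dict) -> dict:
       defaults = {
           "temperature": model_config.get("temperature"),
           "max_new_tokens": model_config.get("max_new_tokens"),
           "top_p": model_config.get("top_p"),
       }
       # Keep only the overrides the caller actually supplied.
       overrides = {k: v for k, v in caller_kwargs.items() if v is not None}
       return {**defaults, **overrides}

   # temperature is overridden; max_new_tokens and top_p fall back to the
   # values stored in model_config at registration time.
   params = resolve_generation_params(
       {"temperature": 0.5, "max_new_tokens": None, "top_p": None},
       {"temperature": 0.7, "max_new_tokens": 512, "top_p": 0.9},
   )
   assert params == {"temperature": 0.5, "max_new_tokens": 512, "top_p": 0.9}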
Inspect Results
^^^^^^^^^^^^^^^^

After inference completes, verify the output and token usage.

.. code-block:: python

   result_df = spark.table("catalog.schema.output_table")
   print(f"Rows: {result_df.count()}")
   display(result_df.limit(10))

.. code-block:: python
   :caption: Aggregate token usage statistics

   from pyspark.sql import functions as F

   token_stats = spark.table("catalog.schema.output_table").agg(
       F.count("*").alias("total_rows"),
       F.sum("prompt_tokens").alias("total_prompt_tokens"),
       F.sum("completion_tokens").alias("total_completion_tokens"),
       F.avg("prompt_tokens").alias("avg_prompt_tokens"),
       F.avg("completion_tokens").alias("avg_completion_tokens"),
   )
   display(token_stats)
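Beyond averages, token-length percentiles are a quick way to confirm that the auto-estimated ``max_model_len`` (see the table above) leaves headroom for real traffic. A minimal sketch, assuming the same ``prompt_tokens`` / ``completion_tokens`` output columns used above:

.. code-block:: python
   :caption: Token-length percentiles (illustrative sketch)

   from pyspark.sql import functions as F

   # p50 / p95 / p99 of prompt and completion lengths; compare the upper
   # percentiles against the max_model_len used at inference time.
   token_pcts = spark.table("catalog.schema.output_table").select(
       F.expr(
           "percentile_approx(prompt_tokens, array(0.5, 0.95, 0.99))"
       ).alias("prompt_token_pcts"),
       F.expr(
           "percentile_approx(completion_tokens, array(0.5, 0.95, 0.99))"
       ).alias("completion_token_pcts"),
   )
   display(token_pcts)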