text_classification module#

The text_classification module provides utilities for supervised training of text classification models, including helpers that load pre-trained HuggingFace models and tokenizers with built-in retry logic and local caching.

Loading Models#

Use load_model_with_retry and load_auto_tokenizer_with_retry to download and load HuggingFace models. These functions cache models locally so subsequent loads skip the download step.
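The retry behavior can be sketched as an exponential-backoff wrapper. This is a hypothetical illustration of the pattern only; the function name with_retries and its parameters are assumptions, not part of the ml_toolkit API:

```python
import time


def with_retries(fn, max_attempts=3, base_delay=1.0):
    """Call fn(), retrying on failure with exponential backoff.

    Hypothetical sketch of the retry pattern; the actual
    ml_toolkit implementation may differ.
    """
    for attempt in range(1, max_attempts + 1):
        try:
            return fn()
        except Exception:
            if attempt == max_attempts:
                raise  # out of attempts: surface the last error
            # back off 1x, 2x, 4x, ... the base delay between attempts
            time.sleep(base_delay * 2 ** (attempt - 1))
```

A transient network failure during download would be absorbed by the earlier attempts, while a persistent error (e.g., a bad model ID) is re-raised once the attempts are exhausted.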

For gated or private models (e.g., Llama), pass your HuggingFace token via hf_token:

Loading a gated model with an explicit token#
from ml_toolkit.ops.helpers.models import load_model_with_retry, load_auto_tokenizer_with_retry
from transformers import AutoModelForSequenceClassification

hf_token = dbutils.secrets.get(scope="ml", key="hf-token")

model = load_model_with_retry(
    AutoModelForSequenceClassification,
    "meta-llama/Llama-2-7b",
    hf_token=hf_token,
)

tokenizer = load_auto_tokenizer_with_retry(
    "meta-llama/Llama-2-7b",
    hf_token=hf_token,
)

Note

When hf_token is None (the default), huggingface_hub automatically falls back to the HF_TOKEN environment variable. If your cluster already has HF_TOKEN set, you can omit the parameter entirely.
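The fallback order described in the note can be illustrated with a small helper. resolve_hf_token is a hypothetical name used here for illustration; huggingface_hub performs this resolution internally (and also consults the cached login token from huggingface-cli login, omitted here for brevity):

```python
import os


def resolve_hf_token(hf_token=None):
    """Simplified sketch of the token fallback order described above:
    an explicit hf_token argument wins; otherwise the HF_TOKEN
    environment variable is used (None if neither is set)."""
    if hf_token is not None:
        return hf_token
    return os.environ.get("HF_TOKEN")
```

With HF_TOKEN set on the cluster, calling the loaders without hf_token behaves the same as passing the token explicitly.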

Loading a public model (no token required)#
from ml_toolkit.ops.helpers.models import load_model_with_retry
from transformers import AutoModelForSequenceClassification

model = load_model_with_retry(
    AutoModelForSequenceClassification,
    "microsoft/deberta-v3-small",
)