From f6c8ec92ce94c44f142a93d8c6d13ada0b8c5ad9 Mon Sep 17 00:00:00 2001 From: realAsma Date: Tue, 2 Jun 2026 18:30:15 +0000 Subject: [PATCH] Refactor llm_qat example: YAML configs + ModelOptArgParser Refactor the examples/llm_qat example to use YAML configs driven by a shared ModelOptArgParser (HfArgumentParser wrapper), introduce ModelOptHFArguments dataclasses, a reusable QuantizationArguments dataclass, and refactored dataset/tokenization utilities. - Align ModelArguments.model_name_or_path default with the README (Qwen/Qwen3-8B) and document the intentional ChatML assistant-mask exclusion of trailing <|im_end|> and the post-header newline. - Add a CPU-only tiny gpt-oss SFT e2e test (test_gpt_oss_sft_toy) plus a tiny_gpt_oss_path fixture, validating that sft.py's TrlParser parses the shared QuantizationArguments incl. the new recipe field; move the release marker onto the heavy GPU pipeline test so the toy test runs in regular CI. Co-Authored-By: Claude Opus 4.8 (1M context) Signed-off-by: realAsma --- .pre-commit-config.yaml | 17 +- CHANGELOG.rst | 4 + examples/llm_qad/README.md | 2 + examples/llm_qat/.gitignore | 2 + examples/llm_qat/ARGUMENTS.md | 51 ++ examples/llm_qat/README.md | 464 +++++----- examples/llm_qat/accelerate_config/fsdp1.yaml | 29 - examples/llm_qat/arguments.py | 139 +++ .../accelerate}/ddp.yaml | 0 .../accelerate}/deepspeed.yaml | 0 .../accelerate}/fsdp2.yaml | 2 +- examples/llm_qat/configs/dataset/README.md | 144 +++ examples/llm_qat/configs/dataset/blend.yaml | 45 + .../configs/dataset/blend_example.yaml | 61 ++ .../llm_qat/configs/dataset/blend_test.yaml | 15 + examples/llm_qat/configs/train/finetune.yaml | 37 + examples/llm_qat/configs/train/qad_nvfp4.yaml | 42 + examples/llm_qat/configs/train/qat_nvfp4.yaml | 38 + .../llm_qat/configs/train/qlora_nvfp4.yaml | 41 + examples/llm_qat/dataset_utils.py | 831 ++++++++++++++++++ examples/llm_qat/launch.sh | 179 ---- examples/llm_qat/llama_factory/README.md | 26 +- .../llama_factory/launch_llamafactory.sh | 16 +- .../llm_qat/llama_factory/llama_config.yaml | 2 +- .../llm_qat/llama_factory/llama_factory.py | 6 +- examples/llm_qat/main.py | 272 ------ .../notebooks/QAT_QAD_Walkthrough.ipynb | 26 +- examples/llm_qat/quantize.py | 106 +++ examples/llm_qat/requirements.txt | 2 +- examples/llm_qat/simple_qat_train.py | 44 +- examples/llm_qat/train.py | 153 ++++ examples/llm_qat/utils.py | 154 +--- examples/vllm_serve/README.md | 2 +- modelopt/torch/distill/plugins/huggingface.py | 39 + modelopt/torch/opt/plugins/transformers.py | 176 +++- .../plugins/transformers_trainer.py | 124 ++- .../ptq/int4_blockwise_weight_only.yaml | 62 ++ tests/examples/conftest.py | 29 +- tests/examples/gpt_oss/test_gpt_oss_qat.py | 71 +- tests/examples/llm_qat/test_assistant_mask.py | 61 ++ .../llm_qat/test_dataset_tokenization.py | 195 ++++ tests/examples/llm_qat/test_llm_qat.py | 180 ++-- .../opt/plugins/test_modelopt_arg_parser.py | 122 +++ 43 files changed, 2981 insertions(+), 1030 deletions(-) create mode 100644 examples/llm_qat/.gitignore create mode 100644 examples/llm_qat/ARGUMENTS.md delete mode 100644 examples/llm_qat/accelerate_config/fsdp1.yaml create mode 100644 examples/llm_qat/arguments.py rename examples/llm_qat/{accelerate_config => configs/accelerate}/ddp.yaml (100%) rename examples/llm_qat/{accelerate_config => configs/accelerate}/deepspeed.yaml (100%) rename examples/llm_qat/{accelerate_config => configs/accelerate}/fsdp2.yaml (91%) create mode 100644 examples/llm_qat/configs/dataset/README.md create mode 100644 examples/llm_qat/configs/dataset/blend.yaml create mode 100644 examples/llm_qat/configs/dataset/blend_example.yaml create mode 100644 examples/llm_qat/configs/dataset/blend_test.yaml create mode 100644 examples/llm_qat/configs/train/finetune.yaml create mode 100644 examples/llm_qat/configs/train/qad_nvfp4.yaml create mode 100644 examples/llm_qat/configs/train/qat_nvfp4.yaml create mode 100644 examples/llm_qat/configs/train/qlora_nvfp4.yaml create mode 100644 examples/llm_qat/dataset_utils.py delete mode 100755 examples/llm_qat/launch.sh delete mode 100644 examples/llm_qat/main.py create mode 100644 examples/llm_qat/quantize.py create mode 100644 examples/llm_qat/train.py create mode 100644 modelopt_recipes/general/ptq/int4_blockwise_weight_only.yaml create mode 100644 tests/examples/llm_qat/test_assistant_mask.py create mode 100644 tests/examples/llm_qat/test_dataset_tokenization.py create mode 100644 tests/unit/torch/opt/plugins/test_modelopt_arg_parser.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index ab3b8c99b77..71fb769ef31 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -109,7 +109,7 @@ repos: examples/llm_eval/lm_eval_hf.py| examples/llm_eval/mmlu.py| examples/llm_eval/modeling.py| - examples/llm_qat/main.py| + examples/llm_qat/train.py| examples/llm_sparsity/weight_sparsity/finetune.py| examples/specdec_bench/specdec_bench/models/specbench_medusa.py| examples/speculative_decoding/main.py| @@ -137,6 +137,21 @@ repos: args: ["-c", "pyproject.toml", "-q"] additional_dependencies: ["bandit[toml]"] + - repo: local + hooks: + - id: generate-arguments-md + name: Regenerate examples/llm_qat/ARGUMENTS.md + entry: bash -c 'python examples/llm_qat/arguments.py --generate_docs examples/llm_qat/ARGUMENTS.md' + language: system + files: >- + (?x)^( + examples/llm_qat/arguments\.py| + modelopt/torch/distill/plugins/huggingface\.py| + modelopt/torch/opt/plugins/transformers\.py| + modelopt/torch/quantization/plugins/transformers_trainer\.py + )$ + pass_filenames: false + - repo: https://github.com/DavidAnson/markdownlint-cli2 rev: v0.18.1 hooks: diff --git a/CHANGELOG.rst b/CHANGELOG.rst index b6a3a979dd5..364a010ddc6 100755 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -43,6 +43,10 @@ Changelog - Add Nemotron-3-Super-120B-A12B PTQ recipes ``modelopt_recipes/models/Nemotron-3-Super-120B-A12B/super-nvfp4.yaml`` (MSE-mixed) and ``super-nvfp4-max-calib.yaml`` (max-calib mixed): NVFP4 W4A4 routed experts + FP8 per-tensor shared experts / Mamba in/out_proj + FP8 KV cache. - Add quantized ``nn.Embedding`` support. ``nn.Embedding`` is now registered in ``QuantModuleRegistry`` and exposes ``weight_quantizer`` (embedding table), ``output_quantizer`` (lookup activations), and a permanently disabled ``input_quantizer`` placeholder — embedding inputs are integer indices and cannot be fake-quantized, so direct ``enable*()`` calls raise. ``export_hf_checkpoint`` packs quantized embedding weights alongside Linear layers. Embedding quantizers are opt-in (``parent_class: nn.Embedding`` disabled by default). - Add post-training quantization (PTQ) example for the Megatron-Bridge framework: ``examples/megatron_bridge/quantize.py`` calibrates an HF model (via ``--quant_cfg`` alias / full config name or a ``--recipe`` YAML, with optional KV-cache quant, weight-only, compression, and MoE expert-ratio calibration) and saves a Megatron checkpoint (tensor / pipeline / expert parallelism supported), and ``examples/megatron_bridge/export.py`` converts that checkpoint to a deployable HuggingFace (unified) checkpoint for TensorRT-LLM / vLLM / SGLang. See `examples/megatron_bridge/README.md `_ for details. +- Refactor ``llm_qat`` example with unified YAML-based configuration and flexible dataset blending. + ``ModelOptArgParser`` adds ``--config`` YAML support with CLI overrides and auto-generates ``ARGUMENTS.md`` from dataclass definitions. + Dataset blending (``configs/dataset/blend.yaml``) supports HuggingFace datasets, local JSON/JSONL/Parquet files, and weighted multi-source blends. + The legacy FSDP1 accelerate config is removed; ``llm_qat`` now documents FSDP2, DeepSpeed, and DDP backends. **Bug Fixes** diff --git a/examples/llm_qad/README.md b/examples/llm_qad/README.md index 68fd0184909..0d86d7b77f7 100644 --- a/examples/llm_qad/README.md +++ b/examples/llm_qad/README.md @@ -2,6 +2,8 @@ Quantization-Aware Distillation (QAD) training scripts for language models using Megatron-LM. These scripts enable training quantized (e.g., NVFP4) student models with knowledge distillation from full-precision teacher models. +> **Note:** For Hugging Face LLM QAD, see the [LLM QAT QAD section](../llm_qat/README.md#end-to-end-qad-example). + ## Overview | Script | Purpose | diff --git a/examples/llm_qat/.gitignore b/examples/llm_qat/.gitignore new file mode 100644 index 00000000000..a013cb22d3e --- /dev/null +++ b/examples/llm_qat/.gitignore @@ -0,0 +1,2 @@ +.cache/ +.dataset_cache/ diff --git a/examples/llm_qat/ARGUMENTS.md b/examples/llm_qat/ARGUMENTS.md new file mode 100644 index 00000000000..6adb050b10f --- /dev/null +++ b/examples/llm_qat/ARGUMENTS.md @@ -0,0 +1,51 @@ +# Argument Reference + + + +## DistillArguments + +| Argument | Type | Default | Description | +|----------|------|---------|-------------| +| `--distill` | `bool` | `False` | Enable training with knowledge distillation. | +| `--teacher_model` | `str` | `None` | The name or path of the teacher model to use for distillation. | +| `--criterion` | `str` | `"logits_loss"` | Distillation loss criterion. Currently only 'logits_loss' is supported. | + +## DataArguments + +| Argument | Type | Default | Description | +|----------|------|---------|-------------| +| `--dataset_config` | `str` | `"configs/dataset/blend.yaml"` | Path to a dataset blend YAML config file. | +| `--train_samples` | `int` | `20000` | Number of training samples to use. | +| `--eval_samples` | `int` | `2000` | Number of evaluation samples to use. | +| `--dataset_seed` | `int` | `42` | Random seed for dataset shuffling. | +| `--dataset_cache_dir` | `str` | `".dataset_cache/tokenized"` | Directory for caching tokenized datasets. | +| `--shuffle` | `bool` | `True` | Whether to shuffle dataset sources (reservoir sampling). | +| `--shuffle_buffer` | `int` | `10000` | Buffer size for streaming shuffle. | +| `--num_proc` | `int` | `16` | Number of CPU workers for tokenization. | + +## ModelArguments + +| Argument | Type | Default | Description | +|----------|------|---------|-------------| +| `--model_name_or_path` | `str` | `"Qwen/Qwen3-8B"` | HuggingFace model name or local path to the base model to quantize/train. | +| `--model_max_length` | `int` | `4096` | Maximum sequence length. Sequences will be right-padded (and possibly truncated). | + +## QuantizeArguments + +| Argument | Type | Default | Description | +|----------|------|---------|-------------| +| `--recipe` | `str` | `None` | Path to a quantization recipe YAML file (built-in or custom). Built-in recipes can be specified by relative path, e.g. 'general/ptq/nvfp4_default-kv_fp8'. Replaces the deprecated --quant_cfg flag. | +| `--quant_cfg` | `modelopt.torch.quantization.config.QuantizeConfig` | `None` | Deprecated: pre-quantize the model with a separate quantization step instead. Specify the quantization format for PTQ/QAT by name (e.g. NVFP4_DEFAULT_CFG). | +| `--calib_size` | `int` | `512` | Specify the calibration size for quantization. The calibration dataset is used to setup the quantization scale parameters for PTQ/QAT. | +| `--compress` | `bool` | `False` | Whether to compress the model weights after quantization for QLoRA. This is useful for reducing the model size. | +| `--calib_batch_size` | `int` | `1` | Batch size for calibration data during quantization. | +| `--output_dir` | `str` | `"quantized_model"` | Directory to save the quantized model checkpoint. | + +## TrainingArguments + +Extends [HuggingFace TrainingArguments](https://huggingface.co/docs/transformers/main_classes/trainer#transformers.TrainingArguments). Only additional arguments are shown below. + +| Argument | Type | Default | Description | +|----------|------|---------|-------------| +| `--cache_dir` | `str` | `None` | | +| `--lora` | `bool` | `False` | Whether to add LoRA (Low-Rank Adaptation) adapter before training. When using real quantization, the LoRA adapter must be set, as quantized weights will be frozen during training. | diff --git a/examples/llm_qat/README.md b/examples/llm_qat/README.md index 1f2bf8458b4..1c3db55d077 100644 --- a/examples/llm_qat/README.md +++ b/examples/llm_qat/README.md @@ -1,363 +1,339 @@ -# Quantization Aware Training (QAT) +# Quantization Aware Training (QAT) and Distillation (QAD) -Quantization Aware Training (QAT) helps to improve the model accuracy beyond post training quantization (PTQ). QAT can further preserve model accuracy at low precisions (e.g., INT4, or FP4 in [NVIDIA Blackwell platform](https://www.nvidia.com/en-us/data-center/technologies/blackwell-architecture/)). +Quantization Aware Training (QAT) improves model accuracy beyond post-training quantization (PTQ) at low precisions (e.g., INT4, FP4 on [NVIDIA Blackwell](https://www.nvidia.com/en-us/data-center/technologies/blackwell-architecture/)). Quantization Aware Distillation (QAD) further improves accuracy by using the original full-precision model as a teacher. + +For background on how QAT enables low-precision accuracy recovery, see the [QAT/QAD blog post](https://developer.nvidia.com/blog/how-quantization-aware-training-enables-low-precision-accuracy-recovery/).
| **Section** | **Description** | **Link** | **Docs** | -| :------------: | :------------: | :------------: | :------------: | -| Pre-Requisites | Required & optional packages to use this technique | \[[Link](#pre-requisites)\] | | -| Getting Started | Learn how to optimize your models using QAT to reduce precision and improve model accuracy post quantization | \[[Link](#getting-started)\] | \[[docs](https://nvidia.github.io/Model-Optimizer/guides/1_quantization.html)\] | -| Support Matrix | View the support matrix to see quantization compatibility and feature availability across different models | \[[Link](#support-matrix)\] | | -| End to End QAT | Example scripts demonstrating quantization techniques for optimizing Hugging Face models | \[[Link](#end-to-end-qat-example)\] | \[[docs](https://nvidia.github.io/Model-Optimizer/guides/1_quantization.html)\] | -| End to End QAD | Example scripts demonstrating quantization aware distillation techniques for optimizing Hugging Face models | \[[Link](#end-to-end-qad-example)\] | \[[docs](https://nvidia.github.io/Model-Optimizer/guides/1_quantization.html)\] | -| Evaluate Accuracy | Evaluating model accuracy after QAT/QAD (with fake quantization) | \[[Link](#testing-qat-model-with-llm-benchmarks-for-accuracy-evaluation)\] | | -| Deployment | Deploying the model after QAT/QAD | \[[Link](#deployment)\] | | -| QLoRA | Model training with reduced GPU memory | \[[Link](#end-to-end-qlora-with-real-quantization)\] | | -| Pre-Quantized Checkpoints | Ready to deploy Hugging Face pre-quantized checkpoints | \[[Link](#pre-quantized-checkpoints)\] | | -| Resources | Extra links to relevant resources | \[[Link](#resources)\] | | +| :---: | :---: | :---: | :---: | +| Quick Start | Prerequisites and setup | \[[Link](#quick-start)\] | | +| End-to-End Example | Run QAT/QAD in 3 steps: quantize, train, export | \[[Link](#run-end-to-end-qatqad-example)\] | | +| Background | How QAT/QAD work and when to use each | \[[Link](#background)\] | \[[docs](https://nvidia.github.io/Model-Optimizer/guides/1_quantization.html)\] | +| Support Matrix | Supported models, quantization formats, and backends | \[[Link](#support-matrix)\] | | +| QLoRA | Model training with reduced GPU memory | \[[Link](#qlora-real-quantization)\] | | +| Advanced Topics | FSDP2 config, YAML options | \[[Link](#advanced-topics)\] | | +| Results | Accuracy benchmarks | \[[Link](#results)\] | | +| Resources | Extra links and references | \[[Link](#resources)\] | |
-## Pre-Requisites - -Please refer to the [llm_ptq/README.md](../llm_ptq/README.md#pre-requisites) for the pre-requisites. - -## Getting Started +## Quick Start -In QAT, a model quantized using [mtq.quantize()](https://nvidia.github.io/Model-Optimizer/reference/generated/modelopt.torch.quantization.model_quant.html#modelopt.torch.quantization.model_quant.quantize) can be directly fine-tuned with the original training pipeline. During QAT, the scaling factors inside quantizers are frozen and the model weights are fine-tuned. +### Prerequisites -To learn more about the QAT feature, please refer to the [documentation](https://nvidia.github.io/Model-Optimizer/guides/_pytorch_quantization.html#quantization-aware-training-qat). +Please refer to [llm_ptq/README.md](../llm_ptq/README.md#pre-requisites) for prerequisites. -Quantization aware distillation (QAD) can be used to further improve accuracy of the model using the original full precision model as a teacher model in cases where QAT is not enough. +The Qwen3-8B example below requires a minimum of **2 x 80GB GPUs**. -### Hugging Face QAT / QAD +## Run End-to-End QAT/QAD Example -> **_NOTE:_** In this example, a QAT and QAD workflow is demonstrated for Huggingface text generation model for supervised fine-tuning (SFT). However, the workflow is general and can be extended to frameworks such as [Megatron-Bridge](https://github.com/NVIDIA-NeMo/Megatron-Bridge) and models beyond LLMs such as CNN-based vision models. +All arguments can be specified via YAML config, CLI flags, or both (CLI overrides YAML). See [Advanced Topics](#advanced-topics). -#### System Requirements +### QAT -The Llama3-8B fine-tuning and QAT below requires a minimum of 2 x 80GB GPUs per machine. +Quantize, fine-tune on labeled data, and export: -#### QAT Example Workflow +```sh +# 1. Quantize +python quantize.py \ + --model_name_or_path Qwen/Qwen3-8B \ + --dataset_config configs/dataset/blend.yaml \ + --recipe general/ptq/nvfp4_default-kv_fp8 \ + --output_dir qwen3-8b-quantized + +# 2. Train +accelerate launch --config-file configs/accelerate/fsdp2.yaml train.py \ + --config configs/train/qat_nvfp4.yaml \ + --model_name_or_path qwen3-8b-quantized \ + --output_dir qwen3-8b-qat-nvfp4 + +# 3. Export +python export.py --pyt_ckpt_path qwen3-8b-qat-nvfp4 --export_path qwen3-8b-qat-deploy +``` -In QAT, a model quantized using [mtq.quantize()](https://nvidia.github.io/Model-Optimizer/reference/generated/modelopt.torch.quantization.model_quant.html#modelopt.torch.quantization.model_quant.quantize) can be directly fine-tuned with the original training pipeline. During QAT, the scaling factors inside quantizers are frozen and the model weights are fine-tuned. +### QAD -Here is the recommended QAT workflow: +Quantize, recover accuracy using the original model as teacher, and export: -Step 1: Train/fine-tune the model in the original precision without quantization. +```sh +# 1. Quantize +python quantize.py \ + --model_name_or_path Qwen/Qwen3-8B \ + --dataset_config configs/dataset/blend.yaml \ + --recipe general/ptq/nvfp4_default-kv_fp8 \ + --output_dir qwen3-8b-quantized + +# 2. Train with distillation +accelerate launch --config-file configs/accelerate/fsdp2.yaml train.py \ + --config configs/train/qad_nvfp4.yaml \ + --model_name_or_path qwen3-8b-quantized \ + --teacher_model Qwen/Qwen3-8B \ + --output_dir qwen3-8b-qad-nvfp4 + +# 3. Export +python export.py --pyt_ckpt_path qwen3-8b-qad-nvfp4 --export_path qwen3-8b-qad-deploy +``` -Step 2: Quantize the model from step 1 with `mtq.quantize()` +Exported checkpoints can be deployed on [TensorRT-LLM](https://github.com/NVIDIA/TensorRT-LLM), [vLLM](https://github.com/vllm-project/vllm), or [SGLang](https://github.com/sgl-project/sglang). See [llm_ptq/README.md](../llm_ptq/README.md#deployment) for deployment instructions. For quick accuracy evaluation without exporting, see [Native Fake-Quantized Evaluation](#native-fake-quantized-evaluation). -Step 3: Train/fine-tune the quantized model with a small learning rate, e.g. 1e-5 for Adam optimizer. +> **Note:** To see the full QAT flow in a single script (quantize + train + save), see [simple_qat_train.py](simple_qat_train.py): +> +> ```sh +> python simple_qat_train.py --model-path meta-llama/Llama-3.2-3B --recipe general/ptq/nvfp4_default-kv_fp8 +> ``` -> **_NOTE:_** `Step 3` listed above is the actual 'Quantization Aware Training' step. The optimal hyperparameter setting for QAT can vary depending on the model and training dataset. The optimal QAT duration depends on the dataset, model etc. +## Background -> **_NOTE:_** We find QAT without the original precision training/fine-tuning (i.e skipping `Step 1` of the QAT workflow from above) to give worse accuracy. Therefore, we recommend un-quantized original precision training/fine-tuning followed by QAT for best accuracy. +### What is QAT? -> **_NOTE:_** Huggingface models trained with `modelopt.torch.speculative` (mtsp) can be used in QAT directly like regular Huggingface models. +**Quantization Aware Training (QAT)** inserts simulated quantization operations into the model graph and then fine-tunes the model so its weights learn to compensate for quantization error. During training, quantization scales are frozen while weights are updated. QAT is a general technique — it learns from labeled data on a quantized model. ```python -import modelopt.torch.opt as mto import modelopt.torch.quantization as mtq +from modelopt.recipe import load_recipe -... +# 1. Load a quantization recipe +recipe = load_recipe("general/ptq/nvfp4_default-kv_fp8") -# [Not shown] load model, tokenizer, data loaders etc -trainer = Trainer(model=model, processing_class=tokenizer, args=training_args, **data_module) +# 2. Quantize the model in-place +model = mtq.quantize(model, recipe.quantize, forward_loop) +# 3. Fine-tune the quantized model +trainer.train() +trainer.save_model() +``` -def forward_loop(model): - for i, data in enumerate(calib_dataloader): - model(data) +> ModelOpt provides accelerated quantization kernels using Triton for NVFP4 QAT. See the [installation guide](https://nvidia.github.io/Model-Optimizer/getting_started/_installation_for_Linux.html#accelerated-quantization-with-triton-kernels). +### What is QAD? -# Quantize the model in-place; The model should be unwrapped from any distributed wrapper -model = mtq.quantize(model, mtq.INT8_DEFAULT_CFG, forward_loop) +**Quantization Aware Distillation (QAD)** is a special case of QAT that uses a teacher model (typically the original unquantized model) to guide the quantized student via a distillation loss. QAD is a **pure accuracy recovery technique** — its goal is to recover accuracy lost from quantization, not to teach the model a new task. -# Save the modelopt quantizer states -torch.save(mto.modelopt_state(model), "modelopt_quantizer_states.pt") +To learn more, read the [QAT/QAD blog post](https://developer.nvidia.com/blog/how-quantization-aware-training-enables-low-precision-accuracy-recovery/). -# To resume training from a checkpoint or load the final QAT model for evaluation, -# load the quantizer states before loading the model weights -# mto.restore_from_modelopt_state(model, modelopt_state_path="modelopt_quantizer_states.pt") -# After loading the quantizer states, load the model weights -# model.load_state_dict(state_dict_from_last_checkpoint) +### When to Use QAT vs QAD -trainer.train() # Train the quantized model (i.e, QAT) +| | **QAT** (without distillation) | **QAD** (with distillation) | +|-|---------|----------------------| +| **What it does** | Fine-tunes a quantized model on labeled data | Recovers quantization accuracy using the original model as teacher | +| **When to use** | The model is already quantized and you want to fine-tune it for a **new task** (e.g., fine-tuning a [GPT-OSS](../gpt-oss/) quantized checkpoint) | You want the **best possible accuracy recovery** after quantization | +| **Recommended workflow** | Start from a quantized checkpoint, fine-tune with task-specific data | Full-precision fine-tuning first, then QAD to recover quantization loss | -# Save the final model weights; An example usage -trainer.save_model() -``` +**QAD is Model Optimizer's recommended strategy for accuracy recovery after quantization.** In our experiments, full-precision fine-tuning followed by QAD delivers the best accuracy, especially at aggressive quantization levels (e.g., NVFP4). The optimal balance between QAT and QAD for a given model and task is an active area of research. -> **_NOTE:_** The example above uses [mto.modelopt_state](https://nvidia.github.io/Model-Optimizer/reference/generated/modelopt.torch.opt.conversion.html#modelopt.torch.opt.conversion.modelopt_state) and [mto.restore_from_modelopt_state](https://nvidia.github.io/Model-Optimizer/reference/generated/modelopt.torch.opt.conversion.html#modelopt.torch.opt.conversion.restore_from_modelopt_state) for saving and restoring of ModelOpt -> modified model. ModelOpt provides additional methods/workflows for saving and restoring ModelOpt modified model. Please see [saving & restoring](https://nvidia.github.io/Model-Optimizer/guides/2_save_load.html) to learn more. +### Using `QATTrainer` and `QADTrainer` -> **_NOTE:_** ModelOpt provides accelerated quantization kernels using Triton that significantly speed up NVFP4 format QAT. For details, see the [installation guide](https://nvidia.github.io/Model-Optimizer/getting_started/_installation_for_Linux.html#accelerated-quantization-with-triton-kernels). +`QATTrainer` is a drop-in replacement for HuggingFace's `Trainer` that handles quantization-aware training seamlessly with various distributed backends (FSDP2, DeepSpeed, DDP): -A simple QAT training example can be found in [simple_qat_train.py](simple_qat_train.py). It can train the model using a single GPU on [Daring-Anteater](https://huggingface.co/datasets/nvidia/Daring-Anteater) dataset. To run: +```python +from modelopt.torch.quantization.plugins.transformers_trainer import QATTrainer -```sh -python simple_qat_train.py --model meta-llama/Llama-3.2-3B +trainer = QATTrainer( + model=model, # pre-quantized model + processing_class=tokenizer, + args=training_args, + **data_module, +) +trainer.train() +trainer.save_model() ``` -To train larger models with distributed training, please refer to [End-to-end QAT Example](#end-to-end-qat-example). - -#### QAD Example Workflow - -Here is an example workflow for performing QAD: - -> **_NOTE:_** QAD workflow is experimental and is subject to change. +`QADTrainer` extends `QATTrainer` with distillation: ```python -import modelopt.torch.opt as mto -import modelopt.torch.distill as mtd -import modelopt.torch.quantization as mtq from modelopt.torch.distill.plugins.huggingface import LMLogitsLoss from modelopt.torch.quantization.plugins.transformers_trainer import QADTrainer - -... - -# [Not shown] load model, tokenizer, data loaders etc -# Create the distillation config distill_config = { - "teacher_model": teacher_model, - "criterion": LMLogitsLoss(), + "teacher_model": teacher_model, + "criterion": LMLogitsLoss(), } trainer = QADTrainer( - model=model, - processing_class=tokenizer, - args=training_args, - quant_args=quant_args, - distill_config=distill_config, - **data_module, + model=model, # pre-quantized model + processing_class=tokenizer, + args=training_args, + distill_config=distill_config, + **data_module, ) - -trainer.train() # Train the quantized model using distillation (i.e, QAD) - -# Save the final student model weights; An example usage +trainer.train() trainer.save_model() ``` -## Support Matrix - -### Model Support List - -This script supports the following models out of the box. +### Quantization Recipes -| Model | Support | -| :---: | :---: | -| LLAMA 2 | ✅ | -| LLAMA 3, 3.1 | ✅ | -| CodeLlama | ✅ | -| Qwen2, 2.5, 3 dense models | ✅ | +Recipes are declarative YAML files that specify the quantization configuration. Built-in recipes are available in [`modelopt_recipes/`](../../modelopt_recipes/): -### Supported quantization configuration for QAT +```sh +# List available built-in recipes +ls modelopt_recipes/general/ptq/ +``` -Current quantization configs can be found [here](https://github.com/NVIDIA/Model-Optimizer/blob/main/modelopt/torch/quantization/config.py). +See [custom calibration](https://nvidia.github.io/Model-Optimizer/guides/_pytorch_quantization.html#advanced-configuration-creation) for creating your own recipe. -These are the recommended quantization configurations for QAT: +## Support Matrix -```python -import modelopt.torch.quantization as mtq +### Supported Models -mtq.INT8_DEFAULT_CFG # INT8 Per-channel weight with INT8 per-tensor activation quantization -mtq.FP8_DEFAULT_CFG # FP8 per-tensor weight & activation quantization -mtq.FP8_2D_BLOCKWISE_WEIGHT_ONLY_CFG # FP8 2D blockwise weightly only quantization -mtq.FP8_PER_CHANNEL_PER_TOKEN_CFG # FP8 per channel weight with per-token activation quantization -mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG # INT4 blockwise weight only quantization -mtq.NVFP4_DEFAULT_CFG # NVFP4 dynamic block weight & activation quantization -mtq.MXFP8_DEFAULT_CFG # MXFP8 per-tensor weight and activation quantization -``` +| Model | Chat Template | Support | +|-------|---------------|---------| +| Qwen2, 2.5, 3, 3.5 dense models; Nemotron ChatML models | ChatML | Yes (chat + assistant-only labels + pretrain) | +| Models with `{% generation %}` chat templates | Model-specific | Yes (chat + assistant-only labels + pretrain) | +| Other models with HuggingFace chat templates, including Llama 2, 3, 3.1 | Model-specific | Yes (chat full-label + pretrain) | -You can also create your own custom config using [this](https://nvidia.github.io/Model-Optimizer/guides/_pytorch_quantization.html#custom-calibration-algorithm) guide. +> **Note:** `apply_chat_template` controls chat formatting. `train_only_assistant_tokens` controls label masking: `auto` uses assistant-only labels when native `{% generation %}` masks or the tested Qwen/Nemotron ChatML heuristic is available, then falls back to all non-padding chat-template tokens; set `train_only_assistant_tokens: true` to require native or ChatML assistant-only labels, or `false` to always train on all chat-template tokens. -## End-to-end QAT Example +### Supported Quantization Formats -This folder contains end-to-end runnable fine-tuning/QAT pipeline where Llama3-8B from huggingface is trained on -[Daring-Anteater](https://huggingface.co/datasets/nvidia/Daring-Anteater) dataset. +| Format | Precision | Recipe | Use Case | +|--------|-----------|--------|----------| +| **NVFP4** | W4A4 + FP8 KV | `general/ptq/nvfp4_default-kv_fp8` | Maximum compression for Blackwell GPUs | +| **FP8** | W8A8 + FP8 KV | `general/ptq/fp8_default-fp8_kv` | Balanced speed and accuracy | +| **INT4** weight-only | W4A16 | `general/ptq/int4_blockwise_weight_only` | Deployable on all Ampere or later GPUs | -First, we need to run un-quantized fine-tuning. Here is the command for that: +> **NVFP4** uses 4-bit FP weights and activations (E2M1 with FP8 dynamic scales) plus FP8 KV cache. Partial variants are available for quantizing only specific layers (e.g., MLP-only, MoE experts-only) — see [`modelopt_recipes/general/ptq/`](../../modelopt_recipes/general/ptq/) for all options. -```sh -./launch.sh --model meta-llama/Meta-Llama-3-8B \ - --num_epochs 2.0 \ - --lr 1e-5 \ - --do_train True \ - --output_dir llama3-finetune -``` +### Supported Backends -This will generate a fine-tuned checkpoint in `output_dir` specified above. You can load this checkpoint, quantize the model, evaluate PTQ results or run additional QAT. -This can be accomplished by specifying the quantization format to the `launch.sh` script. -In this example, we are quantizing the model with INT4 block-wise weights and INT8 per-tensor activation quantization. +| Backend | Config File | Notes | +|---------|------------|-------| +| FSDP2 | `configs/accelerate/fsdp2.yaml` | **Recommended** | +| DDP | `configs/accelerate/ddp.yaml` | Add `--gradient_checkpointing True` | +| DeepSpeed | `configs/accelerate/deepspeed.yaml` | Add `--gradient_checkpointing True` | -To perform PTQ evaluation, run: +Replace `--config-file configs/accelerate/fsdp2.yaml` with the desired backend config in any of the commands above. -```sh -# Load the checkpoint from previous fine-tuning stage, quantize the model and evaluate without additional training -./launch.sh --model llama3-finetune \ - --do_train False \ - --quant_cfg NVFP4_DEFAULT_CFG -``` +## QLoRA (Real Quantization) -To perform QAT, run: +[QLoRA](https://arxiv.org/pdf/2305.14314) reduces training memory by quantizing LoRA backbone weights with real quantization via `mtq.compress()`. ```sh -# Load the quantized checkpoint from previous fine-tuning stage and run additional training (QAT) -./launch.sh --model llama3-finetune \ - --num_epochs 2.0 \ - --lr 1e-5 \ - --do_train True \ - --quant_cfg NVFP4_DEFAULT_CFG \ - --output_dir llama3-qat -``` - -You may alternatively perform QAT with any other quantization formats from **ModelOpt**. Please see more details on the supported quantization formats and how to use them as shown below: - -```python -import modelopt.torch.quantization as mtq +# 1. Quantize with compression +python quantize.py \ + --model_name_or_path Qwen/Qwen3-8B \ + --dataset_config configs/dataset/blend.yaml \ + --recipe general/ptq/nvfp4_default-kv_fp8 \ + --compress True \ + --output_dir qwen3-8b-quantized + +# 2. Train with QLoRA +accelerate launch --config-file configs/accelerate/ddp.yaml train.py \ + --config configs/train/qlora_nvfp4.yaml \ + --model_name_or_path qwen3-8b-quantized \ + --output_dir qwen3-8b-fp4-qlora + +# 3. Export +python export.py \ + --pyt_ckpt_path qwen3-8b-fp4-qlora \ + --export_path qwen3-8b-fp4-qlora-hf -# To learn about the quantization formats and quantization config from modelopt -help(mtq.config) +# 4. Serve with vLLM +vllm serve qwen3-8b-fp4-qlora-hf/base_model --enable-lora \ + --lora-modules adapter=qwen3-8b-fp4-qlora-hf --port 8000 \ + --tokenizer qwen3-8b-fp4-qlora-hf ``` -You could also add your own customized quantization format to `CUSTOM_QUANT_CFG` from `main.py` and perform QAT. - -> **_NOTE:_** QAT requires higher memory than the full-precision fine-tuning. A solution to avoid this extra memory usage is to use [activation checkpointing](https://pytorch.org/docs/stable/checkpoint.html) or gradient checkpointing. Activation checkpointing can be enabled easily with training frameworks such as Huggingface by adding an additional argument `gradient_checkpointing True`. Learn more [here](https://huggingface.co/docs/transformers/v4.20.1/en/perf_train_gpu_one#gradient-checkpointing). Activation checkpointing or gradient checkpointing is enabled by default in this example. +> QLoRA export is not currently supported with FSDP2. -> **_NOTE:_** Like any other model training, the QAT model accuracy can be further improved by optimizing the training -> hyper-parameters such as learning rate, training duration etc. +## Advanced Topics -> **_NOTE:_** `launch.sh` defaults to use `LlamaDecoderLayer` as the transformer layer class. If your model uses a different class, you need to pass `--fsdp_transformer_layer_cls_to_wrap ` to the `launch.sh` script. For example, for `Qwen/Qwen3-8B`, specify `--fsdp_transformer_layer_cls_to_wrap Qwen3DecoderLayer` as an additional argument. +
+FSDP2 and Model-Specific Layer Wrapping -### Results +The default `fsdp2.yaml` uses `TRANSFORMER_BASED_WRAP` with `fsdp_transformer_layer_cls_to_wrap: Qwen3DecoderLayer`. This setting is **model-specific** — if you are training a different model architecture, you must update it to match your model's decoder layer class. -Here is an example result following the workflow above with slightly different hyper-parameters (We used an effective batch size of 128 by adjusting `--train_bs` and `--accum_steps` as per the available GPU memory). -As we can see below, QAT has improved the validation perplexity. +You can either: -You could get slightly different numbers depending on your hyper-parameters - however you should be able to see consistent improvement -for QAT over PTQ alone. +1. **Override via CLI** (recommended for one-off runs): -| | Validation perplexity on `nvidia/Daring-Anteater` dataset | -|-----------------|--------------------| -| Fine-tuned BF16 (No quantization) | 1.45 | -| PTQ with NVFP4 weights & NVFP4 activations on the Fine-tuned BF16 model | 1.56 | -| QAT with NVFP4 weights & NVFP4 activations | 1.49 | + ```sh + accelerate launch --config-file configs/accelerate/fsdp2.yaml \ + --fsdp_transformer_layer_cls_to_wrap LlamaDecoderLayer \ + train.py --config configs/train/qat_nvfp4.yaml ... + ``` -> **_NOTE:_** From our experience, the QAT performs better with a larger batch size, so we recommend using a larger batch size if your hardware allows it. +2. **Create a custom config** (recommended for repeated use): -> **_NOTE:_** If you only use part of the dataset for fine-tuning/QAT, we recommend to use different data samples for fine-tuning and QAT, otherwise there may appear overfitting issues during the QAT stage. + ```sh + cp configs/accelerate/fsdp2.yaml configs/accelerate/fsdp2_llama.yaml + # Edit fsdp2_llama.yaml: change Qwen3DecoderLayer -> LlamaDecoderLayer + ``` -## End-to-end QAD Example +Common layer class names: -To perform QAD with logits loss, run: +| Model Family | `fsdp_transformer_layer_cls_to_wrap` | +|---|---| +| Qwen2, Qwen2.5, Qwen3 | `Qwen3DecoderLayer` (or `Qwen2DecoderLayer`) | +| Llama 2, 3, 3.1 | `LlamaDecoderLayer` | -```sh -./launch.sh --model llama3-finetune \ - --num_epochs 3 \ - --lr 4e-5 \ - --quant_cfg NVFP4_DEFAULT_CFG \ - --do_train True \ - --output_dir llama-qad \ - --distill True -``` +
-> **_NOTE:_** QAD doesn't support FSDP1 () backend - only FSDP2. +
+Configuration -## Testing QAT model with LLM benchmarks for accuracy evaluation +There are two types of configs: -The model generated after QAT can be tested for LLM accuracy evaluation for various LLM benchmarks. After running the fine-tuning, following code can be used to run LLM evaluation for [supported tasks](https://github.com/EleutherAI/lm-evaluation-harness/tree/main/lm_eval/tasks). +- **Dataset configs** (`configs/dataset/`): Define the dataset blend — sources, `blend_size` (total samples), and `splits` (train/eval/test ratios). These are self-contained and determine what gets cached. +- **Training configs** (`configs/train/`): Define training hyperparameters plus runtime caps (`train_samples`, `eval_samples`) that subset the pre-built dataset without retriggering caching. -To run the llm_eval tasks on QAT model, run: +`quantize.py` only needs `--dataset_config` and `--recipe`. `train.py` uses a full training config via `--config`. All arguments can be specified via YAML, CLI flags, or both (CLI overrides YAML). See [ARGUMENTS.md](ARGUMENTS.md) for the full reference, regenerated with `python_pwd examples/llm_qat/arguments.py --generate_docs examples/llm_qat/ARGUMENTS.md`. ```sh -cd ../llm_eval - -python lm_eval_hf.py --model hf \ - --tasks \ - --model_args pretrained=../llm_qat/llama3-qat \ - --quant_cfg NVFP4_DEFAULT_CFG \ - --batch_size 4 +# YAML + CLI override +accelerate launch --config-file configs/accelerate/fsdp2.yaml train.py \ + --config configs/train/qat_nvfp4.yaml --learning_rate 5e-5 ``` -See more details on running LLM evaluation benchmarks [here](../llm_eval/README.md). +See [Dataset Configuration](configs/dataset/README.md) for custom dataset blends and adding new datasets. -## Deployment +
-The final model after QAT/QAD is similar in architecture to that of PTQ model. QAT model simply have updated weights as compared to the PTQ model. It can be deployed to TensorRT-LLM (TRTLLM)/TensorRT/vLLM/SGLang just like a regular **ModelOpt** PTQ model if the quantization format is supported for deployment. +
+Pre-Building the Dataset -To export TRTLLM/vLLM/SGLang compatible checkpoint for the model after QAT (or QAD) model, run: +You can pre-tokenize and cache the dataset before training using `dataset_utils.py`. This is useful for large blends or multi-node setups where you want to build the cache once and reuse it across experiments. ```sh -python export.py --pyt_ckpt_path llama3-qat --export_path llama3-qat-deploy +python dataset_utils.py \ + --dataset_config configs/dataset/blend.yaml \ + --model_name_or_path Qwen/Qwen3-8B ``` -Note: The QAT checkpoint for `w4a8_awq` config can be created by using `--quant_cfg W4A8_AWQ_BETA_CFG` in [QAT example](#end-to-end-qat-example). +The cached dataset is stored under `.dataset_cache/tokenized/` by default (configurable via `--dataset_cache_dir`). The cache key depends on the dataset config (`blend_size`, `splits`, sources) and tokenizer — changing `train_samples` or `eval_samples` in the training config does **not** invalidate the cache. -See more details on deployment of quantized model [here](../llm_ptq/README.md). +
-## End-to-end QLoRA with Real Quantization +## Results -[QLoRA](https://arxiv.org/pdf/2305.14314) is a technique mainly intended for further reducing the training memory requirement of LoRA. In QLoRA, the LoRA backbone weights are quantized to reduce the model footprint. Unlike QAT which uses simulated quantization, QLoRA requires real quantization. To compress the model weights after quantization, we use the `mtq.compress()` function, which currently supports FP8, FP4, and INT4 formats. This feature can be enabled by passing `--compress True` to the `launch.sh` script. For detailed configuration options and patterns, please refer to the `modelopt.torch.quantization.compress` documentation. +\[Coming Soon\] -To evaluate QLoRA quantized model before training, run: +## Native Fake-Quantized Evaluation -```sh -# Load the HF checkpoint, quantize the model and evaluate without additional training -# Also compress the model after quantization -./launch.sh --model meta-llama/Meta-Llama-3-8B \ - --do_train False \ - --quant_cfg NVFP4_DEFAULT_CFG \ - --compress True -``` - -To perform QLoRA training, run: - -```sh -# Load the HF checkpoint, quantize the model, add LoRA adapter, and run additional training -# Also compress the model after quantization -./launch.sh --model meta-llama/Meta-Llama-3-8B \ - --num_epochs 0.5 \ - --lr 1e-3 \ - --do_train True \ - --output_dir llama3-fp4-qlora \ - --quant_cfg NVFP4_DEFAULT_CFG \ - --compress True \ - --lora True -``` - -## QLoRA deployment - -After performing QLoRA training the final checkpoint can be exported for deployment with vLLM using the following command. +ModelOpt quantized models can be saved and restored without exporting to a deployment platform. This is useful for fast evaluation with fake quantization using standard LLM benchmarks (MMLU, WikiText, etc.). See [HuggingFace checkpointing](https://nvidia.github.io/Model-Optimizer/guides/2_save_load.html#modelopt-save-restore-using-huggingface-checkpointing-apis) for details. ```sh -python export.py \ - --pyt_ckpt_path llama3-fp4-qlora \ - --export_path llama3-fp4-qlora-hf \ +cd ../llm_eval +python lm_eval_hf.py --model hf \ + --tasks mmlu,wikitext \ + --model_args pretrained=../llm_qat/qwen3-8b-qat-nvfp4 \ + --batch_size 4 ``` -To deploy with vLLM, run the following command. For more details about QLoRA deployment using vLLM refer to the documentation [here](https://docs.vllm.ai/en/latest/features/lora.html). +See [llm_eval/README.md](../llm_eval/README.md) for supported tasks. -```sh -vllm serve llama3-fp4-qlora-hf/base_model --enable-lora --lora-modules adapter=llama3-fp4-qlora-hf --port 8000 --tokenizer llama3-fp4-qlora-hf -``` - -> _Note: We currently do not support export option for QLoRA models generated using FSDP2._ -> ## Pre-Quantized Checkpoints -- Ready-to-deploy checkpoints \[[🤗 Hugging Face - Nvidia Model Optimizer Collection](https://huggingface.co/collections/nvidia/inference-optimized-checkpoints-with-model-optimizer)\] +- Ready-to-deploy checkpoints: [Hugging Face - NVIDIA Model Optimizer Collection](https://huggingface.co/collections/nvidia/inference-optimized-checkpoints-with-model-optimizer) - Deployable on [TensorRT-LLM](https://github.com/NVIDIA/TensorRT-LLM), [vLLM](https://github.com/vllm-project/vllm) and [SGLang](https://github.com/sgl-project/sglang) -- More models coming soon! ## Resources -- 📅 [Roadmap](https://github.com/NVIDIA/Model-Optimizer/issues/146) -- 📖 [Documentation](https://nvidia.github.io/Model-Optimizer) -- 🎯 [Benchmarks](../benchmark.md) -- 💡 [Release Notes](https://nvidia.github.io/Model-Optimizer/reference/0_changelog.html) -- 🐛 [File a bug](https://github.com/NVIDIA/Model-Optimizer/issues/new?template=1_bug_report.md) -- ✨ [File a Feature Request](https://github.com/NVIDIA/Model-Optimizer/issues/new?template=2_feature_request.md) +- [Roadmap](https://github.com/NVIDIA/Model-Optimizer/issues/146) +- [Documentation](https://nvidia.github.io/Model-Optimizer) +- [Benchmarks](../benchmark.md) +- [Release Notes](https://nvidia.github.io/Model-Optimizer/reference/0_changelog.html) +- [File a bug](https://github.com/NVIDIA/Model-Optimizer/issues/new?template=1_bug_report.md) +- [Feature Request](https://github.com/NVIDIA/Model-Optimizer/issues/new?template=2_feature_request.md) diff --git a/examples/llm_qat/accelerate_config/fsdp1.yaml b/examples/llm_qat/accelerate_config/fsdp1.yaml deleted file mode 100644 index 5e0f5e652d8..00000000000 --- a/examples/llm_qat/accelerate_config/fsdp1.yaml +++ /dev/null @@ -1,29 +0,0 @@ -compute_environment: LOCAL_MACHINE -debug: false -distributed_type: FSDP -downcast_bf16: 'no' -enable_cpu_affinity: false -fsdp_config: - fsdp_activation_checkpointing: true - fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP - fsdp_backward_prefetch: BACKWARD_PRE - fsdp_cpu_ram_efficient_loading: true - fsdp_forward_prefetch: false - fsdp_offload_params: false - fsdp_reshard_after_forward: FULL_SHARD - fsdp_state_dict_type: FULL_STATE_DICT - fsdp_sync_module_states: true - fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer - fsdp_use_orig_params: true - fsdp_version: 1 -machine_rank: 0 -main_training_function: main -mixed_precision: bf16 -num_machines: 1 -num_processes: gpu -rdzv_backend: static -same_network: true -tpu_env: [] -tpu_use_cluster: false -tpu_use_sudo: false -use_cpu: false diff --git a/examples/llm_qat/arguments.py b/examples/llm_qat/arguments.py new file mode 100644 index 00000000000..1cd517482b7 --- /dev/null +++ b/examples/llm_qat/arguments.py @@ -0,0 +1,139 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Shared argument dataclasses for llm_qat scripts (quantize.py, train.py).""" + +from dataclasses import field + +import transformers + +from modelopt.torch.distill.plugins.huggingface import DistillArguments +from modelopt.torch.opt.plugins.transformers import ModelOptArgParser, ModelOptHFArguments +from modelopt.torch.quantization.plugins.transformers_trainer import ( + QuantizationArguments as ModelOptQuantizationArguments, +) + + +class ModelArguments(ModelOptHFArguments): + model_name_or_path: str = field( + default="Qwen/Qwen3-8B", + metadata={ + "help": "HuggingFace model name or local path to the base model to quantize/train." + }, + ) + model_max_length: int = field( + default=4096, + metadata={ + "help": ( + "Maximum sequence length. Sequences will be right-padded (and possibly truncated)." + ) + }, + ) + + +class DataArguments(ModelOptHFArguments): + dataset_config: str = field( + default="configs/dataset/blend.yaml", + metadata={"help": "Path to a dataset blend YAML config file."}, + ) + train_samples: int = field( + default=20000, + metadata={"help": "Number of training samples to use."}, + ) + eval_samples: int = field( + default=2000, + metadata={"help": "Number of evaluation samples to use."}, + ) + dataset_seed: int = field( + default=42, + metadata={"help": "Random seed for dataset shuffling."}, + ) + dataset_cache_dir: str = field( + default=".dataset_cache/tokenized", + metadata={"help": "Directory for caching tokenized datasets."}, + ) + shuffle: bool = field( + default=True, + metadata={"help": "Whether to shuffle dataset sources (reservoir sampling)."}, + ) + shuffle_buffer: int = field( + default=10000, + metadata={"help": "Buffer size for streaming shuffle."}, + ) + num_proc: int = field( + default=16, + metadata={"help": "Number of CPU workers for tokenization."}, + ) + + +class TrainingArguments(ModelOptHFArguments, transformers.TrainingArguments): + cache_dir: str | None = field(default=None) + dataloader_drop_last: bool = field(default=True) + bf16: bool = field(default=True) + lora: bool = field( + default=False, + metadata={ + "help": ( + "Whether to add LoRA (Low-Rank Adaptation) adapter before training. When using real quantization, " + "the LoRA adapter must be set, as quantized weights will be frozen during training." + ) + }, + ) + + +class QuantizeArguments(ModelOptQuantizationArguments): + calib_batch_size: int = field( + default=1, + metadata={"help": "Batch size for calibration data during quantization."}, + ) + output_dir: str = field( + default="quantized_model", + metadata={"help": "Directory to save the quantized model checkpoint."}, + ) + + +TRAINING_ARG_TYPES = (ModelArguments, TrainingArguments, DataArguments, DistillArguments) +QUANTIZE_ARG_TYPES = (ModelArguments, DataArguments, QuantizeArguments) + + +def _unique_arg_types(*arg_type_groups): + return tuple(dict.fromkeys(arg_type for group in arg_type_groups for arg_type in group)) + + +def get_training_arg_parser(): + return ModelOptArgParser(TRAINING_ARG_TYPES) + + +def get_quantize_arg_parser(): + return ModelOptArgParser(QUANTIZE_ARG_TYPES) + + +def get_docs_arg_parser(): + return ModelOptArgParser( + _unique_arg_types(TRAINING_ARG_TYPES, QUANTIZE_ARG_TYPES), + conflict_handler="resolve", + ) + + +def get_training_args(args=None): + return get_training_arg_parser().parse_args_into_dataclasses(args=args) + + +def get_quantize_args(args=None): + return get_quantize_arg_parser().parse_args_into_dataclasses(args=args) + + +if __name__ == "__main__": + get_docs_arg_parser().parse_args_into_dataclasses() diff --git a/examples/llm_qat/accelerate_config/ddp.yaml b/examples/llm_qat/configs/accelerate/ddp.yaml similarity index 100% rename from examples/llm_qat/accelerate_config/ddp.yaml rename to examples/llm_qat/configs/accelerate/ddp.yaml diff --git a/examples/llm_qat/accelerate_config/deepspeed.yaml b/examples/llm_qat/configs/accelerate/deepspeed.yaml similarity index 100% rename from examples/llm_qat/accelerate_config/deepspeed.yaml rename to examples/llm_qat/configs/accelerate/deepspeed.yaml diff --git a/examples/llm_qat/accelerate_config/fsdp2.yaml b/examples/llm_qat/configs/accelerate/fsdp2.yaml similarity index 91% rename from examples/llm_qat/accelerate_config/fsdp2.yaml rename to examples/llm_qat/configs/accelerate/fsdp2.yaml index 3c901d61760..d0004a58d04 100644 --- a/examples/llm_qat/accelerate_config/fsdp2.yaml +++ b/examples/llm_qat/configs/accelerate/fsdp2.yaml @@ -10,7 +10,7 @@ fsdp_config: fsdp_offload_params: false fsdp_reshard_after_forward: true fsdp_state_dict_type: SHARDED_STATE_DICT - fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer + fsdp_transformer_layer_cls_to_wrap: Qwen3DecoderLayer fsdp_version: 2 machine_rank: 0 main_training_function: main diff --git a/examples/llm_qat/configs/dataset/README.md b/examples/llm_qat/configs/dataset/README.md new file mode 100644 index 00000000000..0e336cecb60 --- /dev/null +++ b/examples/llm_qat/configs/dataset/README.md @@ -0,0 +1,144 @@ +# Dataset Blend Configuration + +Dataset blends are defined in YAML files that specify which datasets to mix, +how to sample from them, and how to tokenize them. + +See [`blend_example.yaml`](blend_example.yaml) for a runnable example with all options. + +## Blend YAML Structure + +Blend YAML files define the dataset size, split ratios, and sources: + +```yaml +blend_size: 100000 # total samples to download across all sources +splits: # train/eval/test split ratios (must sum to 1.0) + train: 0.80 + eval: 0.10 + test: 0.10 + +sources: + - hf_path: nvidia/Nemotron-SWE-v1 + split: r2e_gym + ratio: 6000 + category: code +``` + +Processing parameters (`cache_dir`, `shuffle`, `num_proc`, etc.) are set via +`DataArguments` in the training config YAML or CLI flags. `train_samples` and +`eval_samples` in training configs are runtime caps on the pre-split dataset — +changing them does not invalidate the cache. + +## Top-Level Fields + +| Field | Required | Default | Description | +|-------|----------|---------|-------------| +| `blend_size` | no | `100000` | Total samples to download across all sources | +| `splits` | no | `{train: 0.8, eval: 0.1, test: 0.1}` | Relative split weights; e.g. `{train: 8, eval: 1, test: 1}` gives 80/10/10 | + +## Per-Source Fields + +| Field | Required | Default | Description | +|-------|----------|---------|-------------| +| `hf_path` | yes | - | HuggingFace dataset path or local path | +| `ratio` | yes | - | Relative weight (normalized across all sources) | +| `split` | yes | - | Split(s) to load (auto train/eval). See below | +| `dataset_kwargs` | no | `{}` | Extra kwargs passed to `datasets.load_dataset()` (e.g. `{name: "3.0.0"}`) | +| `apply_chat_template` | no | `true` | If true, expects OpenAI messages format | +| `train_only_assistant_tokens` | no | `auto` | Label policy for chat datasets: `auto`, `true`, or `false`. See below | +| `chat_key` | no | `"messages"` | Key containing conversations | +| `category` | no | `""` | Label for logging | + +## Chat Label Masking + +`apply_chat_template` controls message formatting. `train_only_assistant_tokens` controls +which chat-template tokens become labels: + +- `auto` - Use assistant-only labels when the tokenizer supports native + `{% generation %}` masks or the tested Qwen/Nemotron ChatML heuristic; + otherwise train on all non-padding chat-template tokens with a warning. +- `true` - Require assistant-only labels; use native masks or the ChatML + heuristic, and fail if neither is available. +- `false` - Train on all non-padding chat-template tokens. + +## Split Modes + +Specifies which HuggingFace split(s) to load from each source. Samples are pooled across all sources, then globally split into train/eval/test by the top-level `splits` ratios. + +```yaml +# Single split +split: train + +# Comma-separated (equal weight per split) +split: code,math,stem + +# Dict (weighted per split: 3:2:1 ratio) +split: + code: 3 + math: 2 + stem: 1 +``` + +## Dataset Kwargs + +Pass any extra keyword arguments to `datasets.load_dataset()` via `dataset_kwargs`: + +```yaml +# HF config name (e.g. cnn_dailymail) +dataset_kwargs: {name: "3.0.0"} + +# Multiple kwargs +dataset_kwargs: + name: "3.0.0" + trust_remote_code: true + revision: main +``` + +## Streaming and Shuffle + +All HuggingFace datasets are loaded with `streaming=True` to avoid downloading +entire datasets. + +- `shuffle: true` - Reservoir sampling: `dataset.shuffle(buffer_size=N).take(n)`. + Accurate but slower with large buffers. +- `shuffle: false` - Take first N samples: `dataset.take(n)`. Fast and deterministic. + +## Pre-tokenize and Cache + +Pre-tokenize the dataset before training to avoid repeated work: + +```sh +python dataset_utils.py \ + --dataset_config configs/dataset/blend.yaml \ + --model_name_or_path Qwen/Qwen3-8B +``` + +The cached dataset is saved to `dataset_cache_dir` (default: `.dataset_cache/tokenized/`). +Subsequent runs with the same dataset config and tokenizer reuse the cache. +The cache key depends on `blend_size`, `splits`, sources, tokenization settings, +and the tokenizer — changing `train_samples` or `eval_samples` does **not** +invalidate the cache. + +## Adding New Datasets + +Add a source entry to your blend YAML: + +```yaml +sources: + # Chat dataset (OpenAI messages format) + - hf_path: your/dataset + split: train + ratio: 1000 + + # Dataset with different chat key + - hf_path: your/sharegpt-dataset + split: train + ratio: 500 + chat_key: conversations + train_only_assistant_tokens: false + + # Plain text dataset (pretraining-style) + - hf_path: your/text-corpus + split: train + ratio: 500 + apply_chat_template: false +``` diff --git a/examples/llm_qat/configs/dataset/blend.yaml b/examples/llm_qat/configs/dataset/blend.yaml new file mode 100644 index 00000000000..bc0f916d433 --- /dev/null +++ b/examples/llm_qat/configs/dataset/blend.yaml @@ -0,0 +1,45 @@ +# Default SFT blend for QAT/QAD training. +# Uses NVIDIA post-training datasets with chat template tokenization. +# Ratios are relative weights (normalized automatically). + +blend_size: 20000 +splits: + train: 0.90 + eval: 0.05 + test: 0.05 + +sources: + - hf_path: nvidia/Nemotron-SWE-v1 + split: r2e_gym + ratio: 6000 + category: code + + - hf_path: nvidia/Nemotron-Math-v2 + split: medium + ratio: 2500 + category: math + + - hf_path: nvidia/Nemotron-Science-v1 + split: MCQ + ratio: 1500 + category: science_qa + + - hf_path: nvidia/Nemotron-Science-v1 + split: RQA + ratio: 1500 + category: science_qa + + - hf_path: nvidia/Nemotron-Instruction-Following-Chat-v1 + split: chat_if + ratio: 5000 + category: chat + + - hf_path: nvidia/Nemotron-Post-Training-Dataset-v2 + split: chat + ratio: 1500 + category: multilingual + + - hf_path: nvidia/Nemotron-Competitive-Programming-v1 + split: competitive_coding_python_part00 + ratio: 1000 + category: swe_mixed diff --git a/examples/llm_qat/configs/dataset/blend_example.yaml b/examples/llm_qat/configs/dataset/blend_example.yaml new file mode 100644 index 00000000000..10ae209a60d --- /dev/null +++ b/examples/llm_qat/configs/dataset/blend_example.yaml @@ -0,0 +1,61 @@ +# Example dataset blend config demonstrating all mixing options. +# This file is a reference -- not used by any training config by default. +# See configs/dataset/README.md for full schema documentation. + +blend_size: 10000 +splits: + train: 0.80 + eval: 0.10 + test: 0.10 + +sources: + # --- Single split, auto train/eval --- + # Loads one split, shuffles, takes proportional samples, then train_test_split. + - hf_path: nvidia/Nemotron-SWE-v1 + split: r2e_gym + ratio: 3000 + category: code + + # --- Comma-separated splits (equal weight per split) --- + # Loads each split, takes equal samples from each, concatenates, then train_test_split. + - hf_path: nvidia/Nemotron-Post-Training-Dataset-v2 + split: code,math,stem + ratio: 2000 + category: mixed + + # --- Per-split ratios (dict form) --- + # Loads each split, takes weighted samples (3:2:1 here), concatenates, then train_test_split. + - hf_path: nvidia/Nemotron-Post-Training-Dataset-v2 + split: + code: 3 + math: 2 + stem: 1 + ratio: 2000 + category: mixed_weighted + + # --- Pretrain-style (plain text, no chat template) --- + # For datasets with a "text" field. All non-pad tokens become labels. + # Extra load_dataset kwargs (like HF config name) go in dataset_kwargs. + - hf_path: abisee/cnn_dailymail + dataset_kwargs: {name: "3.0.0"} + split: train + ratio: 500 + apply_chat_template: false + category: pretrain + + # --- Custom chat_key --- + # For datasets where conversations are under a different key than "messages". + - hf_path: Magpie-Align/Magpie-Pro-MT-300K-v0.1 + split: train + ratio: 500 + chat_key: conversations + # Train on every non-padding chat-template token instead of assistant-only tokens. + train_only_assistant_tokens: false + category: chat_alt + + # --- Local dataset path (load_from_disk) --- + # Uncomment to use a local pre-downloaded dataset. + # - hf_path: /local/path/to/dataset + # split: train + # ratio: 1000 + # apply_chat_template: false diff --git a/examples/llm_qat/configs/dataset/blend_test.yaml b/examples/llm_qat/configs/dataset/blend_test.yaml new file mode 100644 index 00000000000..cfb2edb81c2 --- /dev/null +++ b/examples/llm_qat/configs/dataset/blend_test.yaml @@ -0,0 +1,15 @@ +# Tiny fast blend for unit tests. +# Uses cnn_dailymail (small, fast to download) with no shuffle for speed. + +blend_size: 200 +splits: + train: 0.80 + eval: 0.10 + test: 0.10 + +sources: + - hf_path: abisee/cnn_dailymail + dataset_kwargs: {name: "3.0.0"} + split: train + ratio: 1 + apply_chat_template: false diff --git a/examples/llm_qat/configs/train/finetune.yaml b/examples/llm_qat/configs/train/finetune.yaml new file mode 100644 index 00000000000..d94ec359954 --- /dev/null +++ b/examples/llm_qat/configs/train/finetune.yaml @@ -0,0 +1,37 @@ +# Full-precision fine-tuning (no quantization) + +# Model +model_name_or_path: # e.g., Qwen/Qwen3-8B +output_dir: # e.g., qwen3-8b-finetune + +# Dataset +dataset_config: configs/dataset/blend.yaml +train_samples: 20000 +eval_samples: 2000 + +# Hyperparameters +num_train_epochs: 1.0 +learning_rate: 1e-5 +per_device_train_batch_size: 2 +per_device_eval_batch_size: 2 +gradient_accumulation_steps: 2 +model_max_length: 4096 +warmup_ratio: 0.05 +lr_scheduler_type: cosine +gradient_checkpointing: true +seed: 42 + +# Checkpointing +load_best_model_at_end: true +save_total_limit: 2 + +# Evaluation +do_eval: true +eval_on_start: true +eval_strategy: steps +eval_steps: 50 + +# Logging +logging_steps: 1 +report_to: + - tensorboard diff --git a/examples/llm_qat/configs/train/qad_nvfp4.yaml b/examples/llm_qat/configs/train/qad_nvfp4.yaml new file mode 100644 index 00000000000..89f61ff9d7e --- /dev/null +++ b/examples/llm_qat/configs/train/qad_nvfp4.yaml @@ -0,0 +1,42 @@ +# QAD: Quantization-Aware Distillation with NVFP4 + +# Model +model_name_or_path: # e.g., Qwen/Qwen3-8B +output_dir: # e.g., qwen3-8b-qad-nvfp4 + +# Distillation +distill: true +teacher_model: # e.g., Qwen/Qwen3-8B + +# Dataset +dataset_config: configs/dataset/blend.yaml +train_samples: 20000 +eval_samples: 2000 + +# Hyperparameters +num_train_epochs: 1.0 +learning_rate: 1e-5 +per_device_train_batch_size: 2 +per_device_eval_batch_size: 2 +gradient_accumulation_steps: 2 +model_max_length: 4096 +warmup_ratio: 0.05 +lr_scheduler_type: cosine +gradient_checkpointing: true +seed: 42 +do_train: true +do_eval: true + +# Checkpointing +load_best_model_at_end: true +save_total_limit: 2 + +# Evaluation +eval_on_start: true +eval_strategy: steps +eval_steps: 50 + +# Logging +logging_steps: 1 +report_to: + - tensorboard diff --git a/examples/llm_qat/configs/train/qat_nvfp4.yaml b/examples/llm_qat/configs/train/qat_nvfp4.yaml new file mode 100644 index 00000000000..44ad824fc7e --- /dev/null +++ b/examples/llm_qat/configs/train/qat_nvfp4.yaml @@ -0,0 +1,38 @@ +# QAT: Quantization-Aware Training with NVFP4 + +# Model +model_name_or_path: # e.g., Qwen/Qwen3-8B +output_dir: # e.g., qwen3-8b-qat-nvfp4 + +# Dataset +dataset_config: configs/dataset/blend.yaml +train_samples: 20000 +eval_samples: 2000 + +# Hyperparameters +num_train_epochs: 1.0 +learning_rate: 1e-5 +per_device_train_batch_size: 2 +per_device_eval_batch_size: 2 +gradient_accumulation_steps: 2 +model_max_length: 4096 +warmup_ratio: 0.05 +lr_scheduler_type: cosine +gradient_checkpointing: true +seed: 42 +do_train: true +do_eval: true + +# Checkpointing +load_best_model_at_end: true +save_total_limit: 2 + +# Evaluation +eval_on_start: true +eval_strategy: steps +eval_steps: 50 + +# Logging +logging_steps: 1 +report_to: + - tensorboard diff --git a/examples/llm_qat/configs/train/qlora_nvfp4.yaml b/examples/llm_qat/configs/train/qlora_nvfp4.yaml new file mode 100644 index 00000000000..489df04e166 --- /dev/null +++ b/examples/llm_qat/configs/train/qlora_nvfp4.yaml @@ -0,0 +1,41 @@ +# QLoRA: LoRA with real quantization (NVFP4) + +# Model +model_name_or_path: # e.g., Qwen/Qwen3-8B +output_dir: # e.g., qwen3-8b-qlora-nvfp4 + +# Dataset +dataset_config: configs/dataset/blend.yaml +train_samples: 20000 +eval_samples: 2000 + +# QLoRA +lora: true + +# Hyperparameters +num_train_epochs: 0.5 +learning_rate: 1e-3 +per_device_train_batch_size: 2 +per_device_eval_batch_size: 2 +gradient_accumulation_steps: 2 +model_max_length: 4096 +warmup_ratio: 0.05 +lr_scheduler_type: cosine +gradient_checkpointing: true +seed: 42 +do_train: true +do_eval: true + +# Checkpointing +load_best_model_at_end: true +save_total_limit: 2 + +# Evaluation +eval_on_start: true +eval_strategy: steps +eval_steps: 50 + +# Logging +logging_steps: 1 +report_to: + - tensorboard diff --git a/examples/llm_qat/dataset_utils.py b/examples/llm_qat/dataset_utils.py new file mode 100644 index 00000000000..fb0cc8b6907 --- /dev/null +++ b/examples/llm_qat/dataset_utils.py @@ -0,0 +1,831 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Dataset blend utilities for QAT/QAD training. + +Provides YAML-driven dataset blending with: +- Multiple dataset sources with configurable ratios +- Chat tokenization via apply_chat_template with configurable label masking +- Pretrain tokenization for plain text datasets +- Distributed rank-aware loading and tokenization with disk caching +- Multi-process tokenization via ``num_proc`` (scales with local GPU count) +- Streaming dataset loading to avoid full downloads + +Usage as standalone CLI (pre-tokenize and cache): + + python dataset_utils.py \\ + --dataset_config configs/dataset/blend.yaml \\ + --model_name_or_path Qwen/Qwen3-1.7B + +Schema reference: See configs/dataset/README.md +""" + +from __future__ import annotations + +import concurrent.futures +import hashlib +import os +import re +import shutil +import tempfile +import time +from dataclasses import dataclass, field +from typing import TYPE_CHECKING + +import datasets +import yaml +from transformers.trainer_pt_utils import LabelSmoother + +from modelopt.torch.utils import print_rank_0, warn_rank_0 +from modelopt.torch.utils.distributed import DistributedProcessGroup +from modelopt.torch.utils.distributed import barrier as dist_barrier +from modelopt.torch.utils.distributed import rank as dist_rank +from modelopt.torch.utils.distributed import size as dist_size + +if TYPE_CHECKING: + from transformers import PreTrainedTokenizerBase + +IGNORE_TOKEN_ID = LabelSmoother.ignore_index + + +@dataclass +class DatasetSourceConfig: + """Configuration for a single dataset source in a blend. + + See configs/dataset/README.md for full schema. + """ + + hf_path: str + ratio: float + split: str | dict[str, float] = "" + dataset_kwargs: dict = field(default_factory=dict) + apply_chat_template: bool = True + train_only_assistant_tokens: bool | str = "auto" + chat_key: str = "messages" + category: str = "" + + def __post_init__(self): + if not self.split: + raise ValueError(f"{self.hf_path}: 'split' is required") + self.train_only_assistant_tokens = _normalize_train_only_assistant_tokens( + self.train_only_assistant_tokens + ) + + +@dataclass +class BlendConfig: + """Top-level configuration for a dataset blend. + + See configs/dataset/README.md for full schema. + """ + + sources: list[DatasetSourceConfig] = field(default_factory=list) + blend_size: int = 100000 + splits: dict[str, float] = field( + default_factory=lambda: {"train": 0.8, "eval": 0.1, "test": 0.1} + ) + + def __post_init__(self): + total = sum(self.splits.values()) + if total <= 0: + raise ValueError("Split ratios must sum to > 0") + self.splits = {k: v / total for k, v in self.splits.items()} + if self.blend_size <= 0: + raise ValueError(f"blend_size must be > 0, got {self.blend_size}") + + +@dataclass +class ParallelConfig: + """Parallelism strategy for dataset processing. + + Combines distributed rank-level sharding with intra-rank multi-process + tokenization via ``num_proc``. The ``effective_num_proc`` property auto-scales + workers per rank based on ``local_world_size`` to avoid CPU over-subscription. + """ + + num_proc: int = 16 + rank: int = 0 + world_size: int = 1 + + @property + def local_world_size(self) -> int: + """Ranks on this node (from ``LOCAL_WORLD_SIZE`` env var set by torchrun/SLURM).""" + lws = os.environ.get("LOCAL_WORLD_SIZE") + if lws: + return int(lws) + if self.is_distributed: + warn_rank_0( + f"LOCAL_WORLD_SIZE not set in distributed mode. " + f"Falling back to global world_size={self.world_size} (assumes single node)." + ) + return self.world_size + return 1 + + @property + def effective_num_proc(self) -> int | None: + """Workers per rank, scaled by local (per-node) rank count. + + Returns ``None`` when sequential processing is appropriate (``num_proc <= 1`` + after scaling), which tells HF ``datasets.map()`` to use the main process. + """ + lws = self.local_world_size + n = max(1, self.num_proc // lws) if lws > 1 else self.num_proc + return n if n > 1 else None + + @property + def is_distributed(self) -> bool: + return self.world_size > 1 + + +def load_blend_config(config_path: str) -> BlendConfig: + """Parse a dataset blend YAML file into a :class:`BlendConfig`.""" + with open(config_path) as f: + raw = yaml.safe_load(f) + + sources = [DatasetSourceConfig(**s) for s in raw.get("sources", [])] + kwargs: dict = {"sources": sources} + if "blend_size" in raw: + kwargs["blend_size"] = raw["blend_size"] + if "splits" in raw: + kwargs["splits"] = raw["splits"] + return BlendConfig(**kwargs) + + +def _normalize_ratios(sources: list[DatasetSourceConfig]) -> list[float]: + """Return normalized ratio weights summing to 1.0.""" + total = sum(s.ratio for s in sources) + if total <= 0: + raise ValueError("Sum of source ratios must be > 0") + return [s.ratio / total for s in sources] + + +def _supports_chatml_heuristic(tokenizer: PreTrainedTokenizerBase) -> bool: + """Check if tokenizer uses ChatML format (<|im_start|>/<|im_end|>).""" + try: + im_start = tokenizer.convert_tokens_to_ids("<|im_start|>") + im_end = tokenizer.convert_tokens_to_ids("<|im_end|>") + return tokenizer.unk_token_id not in (im_start, im_end) + except Exception: + return False + + +def _chat_template_has_generation(tokenizer: PreTrainedTokenizerBase) -> bool: + """Return True if the tokenizer's chat template declares ``{% generation %}``.""" + template = getattr(tokenizer, "chat_template", None) + if template is None: + return False + if isinstance(template, dict): + template = template.get("default") + if not isinstance(template, str): + return False + return bool(re.search(r"\{\%-?\s*generation\s*-?\%\}", template)) + + +def _encode_role(tokenizer: PreTrainedTokenizerBase, role: str) -> list[int]: + return tokenizer.encode(role, add_special_tokens=False) + + +def _matches_role(input_ids: list[int], start: int, role_ids: list[int]) -> bool: + end = start + len(role_ids) + return end <= len(input_ids) and input_ids[start:end] == role_ids + + +def _chatml_assistant_mask(input_ids: list[int], tokenizer: PreTrainedTokenizerBase) -> list[int]: + im_start_id = tokenizer.convert_tokens_to_ids("<|im_start|>") + im_end_id = tokenizer.convert_tokens_to_ids("<|im_end|>") + assistant_ids = _encode_role(tokenizer, "assistant") + newline_id = tokenizer.encode("\n", add_special_tokens=False)[-1] + + # Intentionally excludes the trailing <|im_end|> and the post-header newline so loss is + # focused on assistant content tokens; a minor divergence from the native generation-tag mask. + masks = [0] * len(input_ids) + n_role = len(assistant_ids) + in_assistant = False + skip_remaining = 0 + skip_newline = False + + for i, tid in enumerate(input_ids): + if tid == im_start_id: + in_assistant = _matches_role(input_ids, i + 1, assistant_ids) + if in_assistant: + skip_remaining = n_role + skip_newline = False + continue + if tid == im_end_id: + in_assistant = False + continue + if in_assistant: + if skip_remaining > 0: + skip_remaining -= 1 + if skip_remaining == 0: + skip_newline = True + continue + if skip_newline and tid == newline_id: + skip_newline = False + continue + masks[i] = 1 + + return masks + + +def _normalize_train_only_assistant_tokens(value: bool | str) -> bool | str: + if isinstance(value, bool): + return value + if isinstance(value, str): + normalized = value.strip().lower() + if normalized == "auto": + return normalized + if normalized == "true": + return True + if normalized == "false": + return False + raise ValueError("train_only_assistant_tokens must be one of: auto, true, false") + + +_TESTED_MODEL_FAMILIES = ( + "qwen", + "nemotron", +) + + +def _is_tested_model_family(tokenizer: PreTrainedTokenizerBase) -> bool: + model_name = getattr(tokenizer, "name_or_path", "") or "" + name_lower = model_name.lower() + return any(family in name_lower for family in _TESTED_MODEL_FAMILIES) + + +def make_chat_tokenize_fn( + tokenizer: PreTrainedTokenizerBase, + max_length: int, + chat_key: str = "messages", + train_only_assistant_tokens: bool | str = "auto", +): + """Create a tokenize function for chat datasets using ``apply_chat_template``. + + ``train_only_assistant_tokens`` controls label masking: + - ``"auto"`` uses assistant-only labels when supported, otherwise full labels. + - ``True`` requires supported assistant-only labels. + - ``False`` uses full labels after chat templating. + """ + train_only_assistant_tokens = _normalize_train_only_assistant_tokens( + train_only_assistant_tokens + ) + model_name = getattr(tokenizer, "name_or_path", "unknown") + mask_mode = None + if train_only_assistant_tokens: + supports_chatml = _supports_chatml_heuristic(tokenizer) + is_tested_family = _is_tested_model_family(tokenizer) + if _chat_template_has_generation(tokenizer): + mask_mode = "native" + elif supports_chatml and (is_tested_family or train_only_assistant_tokens is True): + if not is_tested_family: + warn_rank_0( + f"Model '{model_name}' is not from a tested model family " + f"({', '.join(_TESTED_MODEL_FAMILIES)}). " + "Please verify masked tokens manually." + ) + mask_mode = "chatml" + warn_rank_0( + "Chat template lacks {% generation %} support. " + "Using heuristic ChatML-based assistant masking." + ) + elif train_only_assistant_tokens is True: + raise ValueError( + f"Chat template for '{model_name}' does not support " + f"{{% generation %}} and does not use ChatML format. " + f"Set train_only_assistant_tokens: false to train on all chat-template tokens." + ) + else: + warn_rank_0( + f"Assistant token masking is not supported or tested for '{model_name}'. " + "Training on all non-padding chat-template tokens." + ) + + def tokenize(sample): + messages = sample.get(chat_key) + if not messages: + pad_id = tokenizer.pad_token_id or tokenizer.eos_token_id or 0 + return { + "input_ids": [pad_id] * max_length, + "attention_mask": [0] * max_length, + "labels": [IGNORE_TOKEN_ID] * max_length, + } + + try: + result = tokenizer.apply_chat_template( + messages, + tokenize=True, + return_dict=True, + return_assistant_tokens_mask=mask_mode == "native", + padding="max_length", + truncation=True, + max_length=max_length, + ) + except (ValueError, TypeError, KeyError) as e: + warn_rank_0(f"Failed to tokenize sample: {e}. Skipping.") + pad_id = tokenizer.pad_token_id or tokenizer.eos_token_id or 0 + return { + "input_ids": [pad_id] * max_length, + "attention_mask": [0] * max_length, + "labels": [IGNORE_TOKEN_ID] * max_length, + } + + input_ids = result["input_ids"] + if mask_mode == "native": + label_mask = result["assistant_masks"] + elif mask_mode == "chatml": + label_mask = _chatml_assistant_mask(input_ids, tokenizer) + else: + label_mask = result["attention_mask"] + + labels = [tid if mask else IGNORE_TOKEN_ID for tid, mask in zip(input_ids, label_mask)] + + return { + "input_ids": input_ids, + "attention_mask": result["attention_mask"], + "labels": labels, + } + + return tokenize + + +def make_pretrain_tokenize_fn( + tokenizer: PreTrainedTokenizerBase, + max_length: int, +): + """Create a tokenize function for plain text (pretraining-style). + + All non-padding tokens are trainable (labels = input_ids). + """ + + def tokenize(sample): + text = sample.get("text", "") + if not text: + text = sample.get("article", "") or sample.get("content", "") + + input_ids = tokenizer.encode(text, add_special_tokens=True)[:max_length] + cur_len = len(input_ids) + + pad_token = tokenizer.pad_token_id or tokenizer.eos_token_id + if pad_token is None: + raise ValueError("Tokenizer must provide either pad_token_id or eos_token_id") + + attention_mask = [1] * cur_len + [0] * (max_length - cur_len) + labels = list(input_ids) + [IGNORE_TOKEN_ID] * (max_length - cur_len) + input_ids = list(input_ids) + [pad_token] * (max_length - cur_len) + + return { + "input_ids": input_ids[:max_length], + "attention_mask": attention_mask[:max_length], + "labels": labels[:max_length], + } + + return tokenize + + +def _parse_split_spec(split_spec: str | dict[str, float]) -> dict[str, float]: + """Parse a split specification into {split_name: weight} dict. + + Examples: + "train" -> {"train": 1.0} + "code,math,stem" -> {"code": 1.0, "math": 1.0, "stem": 1.0} + {code: 3, math: 2} -> {"code": 3.0, "math": 2.0} + """ + if isinstance(split_spec, dict): + return {k: float(v) for k, v in split_spec.items()} + parts = [p.strip() for p in str(split_spec).split(",") if p.strip()] + return dict.fromkeys(parts, 1.0) + + +def _stream_samples( + hf_path: str, + split_name: str, + num_samples: int, + shuffle: bool, + shuffle_buffer: int = 10000, + seed: int = 42, + rank: int = 0, + world_size: int = 1, + dataset_kwargs: dict | None = None, +) -> list[dict]: + """Stream this rank's portion of ``num_samples`` from a single split. + + When ``world_size > 1``, each rank loads only its shard of the data: + - Local datasets: O(1) random access via ``select()`` + - Streaming datasets: ``skip(offset).take(per_rank)`` (skips without storing) + """ + per_rank = num_samples // world_size + offset = rank * per_rank + if rank == world_size - 1: + per_rank = num_samples - offset # last rank gets remainder + + is_local = os.path.exists(hf_path) + t0 = time.time() + + if is_local: + print_rank_0(f"\tLoading local dataset {hf_path}...") + try: + ds = datasets.load_from_disk(hf_path) + if isinstance(ds, datasets.DatasetDict): + ds = ds[split_name] + if shuffle: + ds = ds.shuffle(seed=seed) + end = min(offset + per_rank, len(ds)) + result = list(ds.select(range(offset, end))) + print_rank_0(f"\tFetched {len(result)} samples in {time.time() - t0:.1f}s") + return result + except Exception as e: + warn_rank_0(f"Failed to load {hf_path} [{split_name}]: {e}. Skipping this split.") + return [] + + print_rank_0(f"\tStreaming {hf_path} [{split_name}]...") + load_kwargs: dict = {"split": split_name, "streaming": True} + load_kwargs.update(dataset_kwargs or {}) + print_rank_0(f"\tFetching {per_rank} samples (rank {rank})...") + try: + ds = datasets.load_dataset(hf_path, **load_kwargs) + if shuffle: + ds = ds.shuffle(seed=seed, buffer_size=shuffle_buffer) + result = list(ds.skip(offset).take(per_rank)) + except Exception as e: + warn_rank_0(f"Failed to stream {hf_path} [{split_name}]: {e}. Skipping this split.") + return [] + print_rank_0(f"\tFetched {len(result)} samples in {time.time() - t0:.1f}s") + return result + + +def _load_source_samples( + source: DatasetSourceConfig, + num_samples: int, + shuffle: bool, + shuffle_buffer: int, + seed: int = 42, + rank: int = 0, + world_size: int = 1, +) -> list[dict]: + """Load raw samples from a single source (all splits combined), rank-aware.""" + split_weights = _parse_split_spec(source.split) + total_weight = sum(split_weights.values()) + + all_samples = [] + remaining = num_samples + split_items = list(split_weights.items()) + + for i, (split_name, weight) in enumerate(split_items): + if i == len(split_items) - 1: + n = remaining # last split gets the remainder + else: + n = max(1, round(weight / total_weight * num_samples)) + remaining -= n + + samples = _stream_samples( + source.hf_path, + split_name, + n, + shuffle, + shuffle_buffer, + seed, + rank, + world_size, + dataset_kwargs=source.dataset_kwargs, + ) + print_rank_0( + f" {source.hf_path} [{split_name}]: requested {n}, got {len(samples)}" + f" (rank {rank}/{world_size})" + ) + all_samples.extend(samples) + + return all_samples + + +_dataset_cache: dict[str, datasets.DatasetDict] = {} + + +def _tokenizer_fingerprint(tokenizer: PreTrainedTokenizerBase) -> tuple[str, str]: + """Return ``(short_name, fingerprint)`` for cache key construction. + + The fingerprint captures class name, vocab size, and special token IDs so that + tokenizers of the same class but different vocabularies produce distinct caches. + """ + cls_name = type(tokenizer).__name__ + parts = [ + cls_name, + f"vocab={tokenizer.vocab_size}", + f"eos={tokenizer.eos_token_id}", + f"bos={getattr(tokenizer, 'bos_token_id', None)}", + f"pad={tokenizer.pad_token_id}", + f"unk={getattr(tokenizer, 'unk_token_id', None)}", + ] + return cls_name, "|".join(parts) + + +def _build_cache_path( + config: BlendConfig, + tokenizer: PreTrainedTokenizerBase, + max_length: int, + cache_dir: str, +) -> str: + """Build a deterministic cache path for the blend config.""" + base = cache_dir if cache_dir else tempfile.gettempdir() + + tok_name, tok_fp = _tokenizer_fingerprint(tokenizer) + splits_str = ",".join(f"{k}:{v}" for k, v in sorted(config.splits.items())) + sig = f"{tok_fp}|{max_length}|{config.blend_size}|{splits_str}" + for s in config.sources: + sig += ( + f"|{s.hf_path}|{s.ratio}|{s.split}|{s.dataset_kwargs}" + f"|chat={s.apply_chat_template}|chat_key={s.chat_key}" + f"|train_only_assistant_tokens={s.train_only_assistant_tokens}" + ) + cache_key = hashlib.sha1(sig.encode()).hexdigest()[:12] + + return os.path.join( + base, + f"llm_qat_{tok_name}_blend{config.blend_size}_{cache_key}", + ) + + +def _is_non_empty_dir(path: str) -> bool: + return os.path.isdir(path) and bool(os.listdir(path)) + + +def _load_cached_dataset(cache_path: str) -> datasets.DatasetDict | None: + """Try to load from in-memory or disk cache. Returns ``None`` if not cached.""" + if cache_path in _dataset_cache: + print_rank_0(f"Using in-memory cached dataset: {cache_path}") + return _dataset_cache[cache_path] + + if _is_non_empty_dir(cache_path): + if os.path.exists(os.path.join(cache_path, "dataset_dict.json")): + print_rank_0(f"Using disk-cached dataset: {cache_path}") + _dataset_cache[cache_path] = datasets.load_from_disk(cache_path) + return _dataset_cache[cache_path] + + return None + + +_EMPTY_TOKENIZED = {"input_ids": [], "attention_mask": [], "labels": []} + + +def _concat_parts(parts: list[datasets.Dataset]) -> datasets.Dataset: + """Concatenate non-empty dataset parts, returning an empty dataset if all are empty.""" + non_empty = [p for p in parts if len(p) > 0] + if not non_empty: + return datasets.Dataset.from_dict(_EMPTY_TOKENIZED) + if len(non_empty) == 1: + return non_empty[0] + return datasets.concatenate_datasets(non_empty) + + +def _load_all_source_samples( + config: BlendConfig, + norm_ratios: list[float], + parallel: ParallelConfig, + shuffle: bool, + shuffle_buffer: int, + seed: int, +) -> tuple[list[list[dict]], list[int]]: + """Load raw samples from all sources for this rank (flat, no split). + + Returns: + (per_source_samples, per_source_counts) where + ``per_source_samples[i]`` is the list of raw dicts for source *i* + and ``per_source_counts[i] = len(per_source_samples[i])``. + """ + per_source_samples: list[list[dict]] = [] + per_source_counts: list[int] = [] + + print_rank_0(f"Loading {len(config.sources)} sources into blend...") + + num_sources = len(config.sources) + for idx, (source, norm_ratio) in enumerate(zip(config.sources, norm_ratios), 1): + source_total = max(1, round(norm_ratio * config.blend_size)) + + cat_label = f" [{source.category}]" if source.category else "" + print_rank_0( + f"Source [{idx}/{num_sources}]: {source.hf_path}{cat_label}" + f" (ratio={norm_ratio:.3f}, n={source_total})" + ) + + samples = _load_source_samples( + source, + source_total, + shuffle, + shuffle_buffer, + seed, + parallel.rank, + parallel.world_size, + ) + per_source_samples.append(samples) + per_source_counts.append(len(samples)) + + local_total = sum(per_source_counts) + group = DistributedProcessGroup(group=None) + global_total = DistributedProcessGroup.get_dist_syncd_obj( + local_total, group, op=lambda objs: sum(objs) + ) + print_rank_0(f"Total raw samples across all ranks: {global_total}") + return per_source_samples, per_source_counts + + +def _tokenize_source_split( + source: DatasetSourceConfig, + raw_samples: list[dict], + tokenizer: PreTrainedTokenizerBase, + max_length: int, + parallel: ParallelConfig, +) -> datasets.Dataset: + """Tokenize raw samples for a single source and split. + + Data is already rank-specific (loaded by ``_load_all_source_samples``), + so no sharding is needed here. Uses ``parallel.effective_num_proc`` for + multi-process tokenization. + """ + if source.apply_chat_template: + tokenize_fn = make_chat_tokenize_fn( + tokenizer, + max_length, + chat_key=source.chat_key, + train_only_assistant_tokens=source.train_only_assistant_tokens, + ) + else: + tokenize_fn = make_pretrain_tokenize_fn(tokenizer, max_length) + + ds = datasets.Dataset.from_list(raw_samples) + if len(ds) == 0: + return datasets.Dataset.from_dict(_EMPTY_TOKENIZED) + + print_rank_0( + f"\tTokenizing {len(raw_samples)} samples (num_proc={parallel.effective_num_proc})..." + ) + tokenized = ds.map( + tokenize_fn, + remove_columns=list(ds.features), + num_proc=parallel.effective_num_proc, + desc=f"Tokenizing {source.hf_path} rank {parallel.rank}/{parallel.world_size}", + ) + before = len(tokenized) + tokenized = tokenized.filter( + lambda x: any(label != IGNORE_TOKEN_ID for label in x["labels"]), + num_proc=parallel.effective_num_proc, + ) + dropped = before - len(tokenized) + if dropped: + warn_rank_0( + f"Dropped {dropped}/{before} samples with no valid labels " + f"from {source.hf_path} (all labels are IGNORE_INDEX after tokenization)." + ) + return tokenized + + +def _merge_distributed_shards( + cache_path: str, + local_flat: datasets.Dataset, + parallel: ParallelConfig, + splits: dict[str, float], + seed: int = 42, +) -> datasets.DatasetDict: + """Save per-rank flat data, merge on rank 0, shuffle, and split by ratios. + + Each rank saves its local tokenized data as a flat Dataset. Rank 0 loads + all shards, concatenates, shuffles deterministically, then splits by the + configured ratios. + """ + print_rank_0(f"\tSaving rank {parallel.rank} data to disk...") + temp_dir = os.path.join(cache_path, "temp") + rank_path = os.path.join(temp_dir, f"rank_{parallel.rank}") + os.makedirs(rank_path, exist_ok=True) + local_flat.save_to_disk(rank_path) + + dist_barrier() + + if parallel.rank == 0: + + def load_rank(r: int) -> datasets.Dataset: + return datasets.load_from_disk(os.path.join(temp_dir, f"rank_{r}")) + + print_rank_0(f"\tMerging {parallel.world_size} shards...") + with concurrent.futures.ThreadPoolExecutor(max_workers=min(4, parallel.world_size)) as pool: + all_shards = list(pool.map(load_rank, range(parallel.world_size))) + + merged = _concat_parts(all_shards) + merged = merged.shuffle(seed=seed) + + total = len(merged) + split_datasets = {} + offset = 0 + split_items = list(splits.items()) + for i, (split_name, ratio) in enumerate(split_items): + if i == len(split_items) - 1: + count = total - offset + else: + count = round(ratio * total) + split_datasets[split_name] = merged.select(range(offset, offset + count)) + offset += count + + result = datasets.DatasetDict(split_datasets) + result.save_to_disk(cache_path) + + shutil.rmtree(temp_dir, ignore_errors=True) + split_summary = ", ".join(f"{k}={len(v)}" for k, v in result.items()) + print_rank_0(f"Cached blended dataset to {cache_path} ({split_summary})") + + dist_barrier() + + return datasets.load_from_disk(cache_path) + + +def build_blend_dataset( + config: BlendConfig, + tokenizer: PreTrainedTokenizerBase, + max_length: int, + seed: int = 42, + cache_dir: str = ".dataset_cache/tokenized", + shuffle: bool = True, + shuffle_buffer: int = 10000, + num_proc: int = 16, +) -> datasets.DatasetDict: + """Build a blended, tokenized dataset from a :class:`BlendConfig`. + + Returns a ``DatasetDict`` with keys matching ``config.splits`` + (e.g. ``"train"``, ``"eval"``, ``"test"``). + """ + cache_path = _build_cache_path(config, tokenizer, max_length, cache_dir) + + cached = _load_cached_dataset(cache_path) + if cached is not None: + return cached + + rank, world_size = dist_rank(), dist_size() + parallel = ParallelConfig(num_proc=num_proc, rank=rank, world_size=world_size) + + if rank == 0: + os.makedirs(cache_path, exist_ok=True) + dist_barrier() + + norm_ratios = _normalize_ratios(config.sources) + per_source_samples, per_source_counts = _load_all_source_samples( + config, norm_ratios, parallel, shuffle, shuffle_buffer, seed + ) + + print_rank_0(f"Tokenizing {len(config.sources)} sources...") + tokenized_parts: list[datasets.Dataset] = [] + for source, samples in zip(config.sources, per_source_samples): + if samples: + tokenized_parts.append( + _tokenize_source_split(source, samples, tokenizer, max_length, parallel) + ) + + local_flat = _concat_parts(tokenized_parts) + + print_rank_0("Merging distributed shards...") + result = _merge_distributed_shards(cache_path, local_flat, parallel, config.splits, seed) + _dataset_cache[cache_path] = result + return result + + +def main(): + import transformers + from arguments import DataArguments, ModelArguments + + from modelopt.torch.opt.plugins.transformers import ModelOptArgParser + + parser = ModelOptArgParser((ModelArguments, DataArguments)) + model_args, data_args = parser.parse_args_into_dataclasses() + + config = load_blend_config(data_args.dataset_config) + + tokenizer = transformers.AutoTokenizer.from_pretrained( + model_args.model_name_or_path, model_max_length=model_args.model_max_length + ) + if tokenizer.pad_token_id is None: + tokenizer.pad_token_id = tokenizer.eos_token_id + + ds = build_blend_dataset( + config, + tokenizer, + model_args.model_max_length, + seed=data_args.dataset_seed, + cache_dir=data_args.dataset_cache_dir, + shuffle=data_args.shuffle, + shuffle_buffer=data_args.shuffle_buffer, + num_proc=data_args.num_proc, + ) + split_summary = ", ".join(f"{k}: {len(v)}" for k, v in ds.items()) + print(f"Built dataset: {split_summary}") + + +if __name__ == "__main__": + main() diff --git a/examples/llm_qat/launch.sh b/examples/llm_qat/launch.sh deleted file mode 100755 index cc3adc74fe3..00000000000 --- a/examples/llm_qat/launch.sh +++ /dev/null @@ -1,179 +0,0 @@ -#!/bin/bash -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -set -eo pipefail - -export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True - -# Helper function to parse a single argument value -parse_value() { - if [[ "$1" != *=* ]]; then shift; fi - echo "${1#*=}" -} - -while [ $# -gt 0 ]; do - case "$1" in - --model*) MODEL=$(parse_value "$@"); [[ "$1" != *=* ]] && shift ;; - --output_dir*) OUTPUT_DIR=$(parse_value "$@"); [[ "$1" != *=* ]] && shift ;; - --dataset*) DATASET=$(parse_value "$@"); [[ "$1" != *=* ]] && shift ;; - --train_size*) TRAIN_SIZE=$(parse_value "$@"); [[ "$1" != *=* ]] && shift ;; - --eval_size*) EVAL_SIZE=$(parse_value "$@"); [[ "$1" != *=* ]] && shift ;; - --num_epochs*) NUM_EPOCHS=$(parse_value "$@"); [[ "$1" != *=* ]] && shift ;; - --max_steps*) MAX_STEPS=$(parse_value "$@"); [[ "$1" != *=* ]] && shift ;; - --save_steps*) SAVE_STEPS=$(parse_value "$@"); [[ "$1" != *=* ]] && shift ;; - --accum_steps*) ACCUM_STEPS=$(parse_value "$@"); [[ "$1" != *=* ]] && shift ;; - --lr*) LR=$(parse_value "$@"); [[ "$1" != *=* ]] && shift ;; - --quant_cfg*) QUANT_CFG=$(parse_value "$@"); [[ "$1" != *=* ]] && shift ;; - --compress*) COMPRESS=$(parse_value "$@"); [[ "$1" != *=* ]] && shift ;; - --calib_size*) CALIB_SIZE=$(parse_value "$@"); [[ "$1" != *=* ]] && shift ;; - --train_bs*) TRAIN_BS=$(parse_value "$@"); [[ "$1" != *=* ]] && shift ;; - --eval_bs*) EVAL_BS=$(parse_value "$@"); [[ "$1" != *=* ]] && shift ;; - --do_train*) DO_TRAIN=$(parse_value "$@"); [[ "$1" != *=* ]] && shift ;; - --lora*) LORA=$(parse_value "$@"); [[ "$1" != *=* ]] && shift ;; - --teacher_model*) TEACHER_MODEL=$(parse_value "$@"); [[ "$1" != *=* ]] && shift ;; - --distill*) DISTILL=$(parse_value "$@"); [[ "$1" != *=* ]] && shift ;; - --fsdp_transformer_layer_cls_to_wrap*) FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP=$(parse_value "$@"); [[ "$1" != *=* ]] && shift ;; - --max_seq_length*) MAX_SEQ_LENGTH=$(parse_value "$@"); [[ "$1" != *=* ]] && shift ;; - --backend*) BACKEND=$(parse_value "$@"); [[ "$1" != *=* ]] && shift ;; - *) - >&2 printf "Error: Invalid argument ${1#*=}\n" - exit 1 - ;; - esac - shift -done - -set -x - -# Get the default value for save_steps based on the available number of GPUs -GPU_COUNT=$(python -c "import torch; print(torch.cuda.device_count())") -# Calculate save_steps -DEFAULT_SAVE_STEPS=$((192 / GPU_COUNT)) - -MODEL=${MODEL:-"meta-llama/Llama-2-7b-hf"} -OUTPUT_DIR=${OUTPUT_DIR:-"llama2-finetune"} -DATASET=${DATASET:-"Daring-Anteater"} -MAX_SEQ_LENGTH=${MAX_SEQ_LENGTH:-4096} -TRAIN_SIZE=${TRAIN_SIZE:-0} -EVAL_SIZE=${EVAL_SIZE:-0} -NUM_EPOCHS=${NUM_EPOCHS:-1} -SAVE_STEPS=${SAVE_STEPS:-$DEFAULT_SAVE_STEPS} -ACCUM_STEPS=${ACCUM_STEPS:-1} -LR=${LR:-"1e-4"} -CALIB_SIZE=${CALIB_SIZE:-512} -TRAIN_BS=${TRAIN_BS:-4} -EVAL_BS=${EVAL_BS:-4} -DO_TRAIN=${DO_TRAIN:-True} -LORA=${LORA:-"False"} -COMPRESS=${COMPRESS:-"False"} -DISTILL=${DISTILL:-"False"} -TEACHER_MODEL=${TEACHER_MODEL:-$MODEL} -FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP=${FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP:-"LlamaDecoderLayer"} -BACKEND=${BACKEND:-"fsdp2"} - -if [ -z $QUANT_CFG ]; then - QUANT_ARGS="" -else - QUANT_ARGS="--quant_cfg $QUANT_CFG --calib_size $CALIB_SIZE" -fi - -OPTIONAL_ARGS="" -if [ ! -z $MAX_STEPS ]; then - OPTIONAL_ARGS="$OPTIONAL_ARGS --max_steps $MAX_STEPS" -fi - -# if compress is true, set backend to ddp -if [[ "${COMPRESS,,}" == "true" ]]; then - BACKEND="ddp" -fi - -# Configure backend-specific settings -GRADIENT_CHECKPOINTING_ARGS="" -case "${BACKEND,,}" in - "fsdp1"|"fsdp") - CONFIG_FILE="fsdp1.yaml" - FSDP_ARGS="--fsdp_transformer_layer_cls_to_wrap $FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP" - ;; - "fsdp2") - echo "Using FSDP2 instead of FSDP1." - CONFIG_FILE="fsdp2.yaml" - FSDP_ARGS="--fsdp_transformer_layer_cls_to_wrap $FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP" - ;; - "ddp") - CONFIG_FILE="ddp.yaml" - FSDP_ARGS="" - GRADIENT_CHECKPOINTING_ARGS="--gradient_checkpointing True" - ;; - "deepspeed") - CONFIG_FILE="deepspeed.yaml" - FSDP_ARGS="" - GRADIENT_CHECKPOINTING_ARGS="--gradient_checkpointing True" - ;; - *) - echo "Error: Invalid backend '$BACKEND'. Supported backends: fsdp1, fsdp2, ddp, deepspeed" - exit 1 - ;; -esac - -# TODO: Remove this after simple distillation is supported -DISTILLATION_ARGS="" -if [[ "${DISTILL,,}" == "true" ]]; then - DISTILLATION_ARGS="--distill $DISTILL --teacher_model $TEACHER_MODEL" - if [[ "${BACKEND,,}" == "fsdp1" ]]; then - echo "Error: Distillation does not support FSDP1. Use FSDP2 instead." - exit 1 - elif [[ "${BACKEND,,}" == "fsdp2" ]]; then - # Distillation does not work with memory efficient loading for FSDP - FSDP_ARGS="$FSDP_ARGS --fsdp_cpu_ram_efficient_loading False" - fi -fi - -CMD="accelerate launch --config-file accelerate_config/$CONFIG_FILE $FSDP_ARGS \ - main.py \ - --model_name_or_path $MODEL \ - --model_max_length $MAX_SEQ_LENGTH \ - --dataloader_drop_last True \ - --do_train $DO_TRAIN \ - --do_eval True \ - --output_dir $OUTPUT_DIR \ - --dataset $DATASET \ - --train_size $TRAIN_SIZE \ - --eval_size $EVAL_SIZE \ - --num_train_epochs $NUM_EPOCHS \ - --per_device_train_batch_size $TRAIN_BS \ - --per_device_eval_batch_size $EVAL_BS \ - --gradient_accumulation_steps $ACCUM_STEPS \ - --eval_accumulation_steps 1 \ - --save_strategy steps \ - --save_steps $SAVE_STEPS \ - --eval_strategy steps \ - --eval_steps $SAVE_STEPS \ - --load_best_model_at_end True \ - --save_total_limit 2 \ - --learning_rate $LR \ - --weight_decay 0.0 \ - --warmup_steps 0.1 \ - --lr_scheduler_type linear \ - --logging_steps 1 \ - --report_to tensorboard \ - --lora $LORA \ - --compress $COMPRESS \ - $GRADIENT_CHECKPOINTING_ARGS $QUANT_ARGS $OPTIONAL_ARGS $DISTILLATION_ARGS -" - -start_time=$(date +%s) -sh -c "$CMD" -echo "Total time taken: $(( $(date +%s) - $start_time )) seconds" diff --git a/examples/llm_qat/llama_factory/README.md b/examples/llm_qat/llama_factory/README.md index b72c3dbfdfe..efa511c16c2 100644 --- a/examples/llm_qat/llama_factory/README.md +++ b/examples/llm_qat/llama_factory/README.md @@ -4,7 +4,7 @@ This directory provides integration between [LLaMA-Factory](https://github.com/h ## Quick Start -### Basic QAT/QAD Training with FSDP +### Basic QAT/QAD Training ```bash ./launch_llamafactory.sh llama_config.yaml @@ -12,24 +12,18 @@ This directory provides integration between [LLaMA-Factory](https://github.com/h > **_NOTE:_** The `launch_llamafactory.sh` script automatically installs LLaMA Factory if it's not already present in your environment. -In order to train using FSDP2: - -```sh -./launch_llamafactory.sh llama_config.yaml --use_fsdp2 true -``` - -By default, the script uses [fsdp1.yaml](../accelerate_config/fsdp1.yaml) and [fsdp2.yaml](../accelerate_config/fsdp2.yaml) for FSDP and FSDP2 training respectively. +By default, the script uses [fsdp2.yaml](../configs/accelerate/fsdp2.yaml) for distributed training. **Use Custom FSDP Arguments**: Pass additional FSDP parameters using the `FSDP_ARGS` environment variable: ```bash -FSDP_ARGS="--fsdp_transformer_layer_cls_to_wrap LlamaDecoderLayer" \ +FSDP_ARGS="--fsdp_transformer_layer_cls_to_wrap Qwen3DecoderLayer" \ ./launch_llamafactory.sh llama_config.yaml ``` -> **_NOTE:_** The default `fsdp*.yml` files use `LlamaDecoderLayer` as the transformer layer class. If your model uses a different layer class, you can either pass `--fsdp_transformer_layer_cls_to_wrap ` to the `launch_llamafactory.sh` script or provide a custom FSDP configuration file. +> **_NOTE:_** The default `fsdp2.yaml` uses `Qwen3DecoderLayer` as the transformer layer class. If your model uses a different layer class (e.g., `LlamaDecoderLayer` for Llama models), pass `--fsdp_transformer_layer_cls_to_wrap ` to the `launch_llamafactory.sh` script or provide a custom FSDP configuration file. See [Supported Backends](../README.md#supported-backends) for details. **Custom Config File**: @@ -83,14 +77,14 @@ val_size: 0.1 # Validation split ratio ### ModelOpt Configuration modelopt: - quant_cfg: NVFP4_DEFAULT_CFG # Quantization format - calib_size: 1024 # Calibration dataset size - compress: false # Enable weight compression - distill: false # Modify distill to true for QAD - teacher_model: /path/to/teacher/model # For QAD (optional) + recipe: general/ptq/nvfp4_default-kv_fp8 # Quantization recipe (built-in or custom path) + calib_size: 1024 # Calibration dataset size + compress: false # Enable weight compression + distill: false # Modify distill to true for QAD + teacher_model: /path/to/teacher/model # For QAD (optional) ``` -> **_NOTE:_** `compress: true` enables weight compression and will by default use [ddp.yaml](../accelerate_config/ddp.yaml). +> **_NOTE:_** `compress: true` enables weight compression and will by default use [ddp.yaml](../configs/accelerate/ddp.yaml). > **_NOTE:_** When training without [cli](#training-using-cli), avoid using deepspeed option in the YAML configuration file. ## Deployment diff --git a/examples/llm_qat/llama_factory/launch_llamafactory.sh b/examples/llm_qat/llama_factory/launch_llamafactory.sh index 03effc89aeb..d90478809ef 100644 --- a/examples/llm_qat/llama_factory/launch_llamafactory.sh +++ b/examples/llm_qat/llama_factory/launch_llamafactory.sh @@ -140,8 +140,8 @@ if [[ $1 == "help" ]] || [[ $1 == "-h" ]]; then echo "Arguments:" echo " YAML config file for llama_factory" echo " --accelerate_config Accelerate config file (optional)" - echo " --use_fsdp2 Use FSDP2 instead of FSDP1 (default: false)" - echo " $0 llama_config.yaml --accelerate_config ../accelerate_config/fsdp2.yaml" + echo " --use_fsdp2 Use FSDP2 (default: true)" + echo " $0 llama_config.yaml --accelerate_config ../configs/accelerate/fsdp2.yaml" echo "" echo "or" echo "" @@ -184,7 +184,7 @@ else # Move to next argument shift ACCELERATE_CONFIG="" - USE_FSDP2="false" + USE_FSDP2="true" while [ $# -gt 0 ]; do case "$1" in @@ -238,20 +238,14 @@ else # Set default accelerate config if not provided if [[ -z "$ACCELERATE_CONFIG" ]]; then if check_compress_enabled "$CONFIG_FILE"; then - ACCELERATE_CONFIG="$SCRIPT_DIR/../accelerate_config/ddp.yaml" - elif [[ "${USE_FSDP2,,}" == "true" ]]; then - ACCELERATE_CONFIG="$SCRIPT_DIR/../accelerate_config/fsdp2.yaml" + ACCELERATE_CONFIG="$SCRIPT_DIR/../configs/accelerate/ddp.yaml" else - ACCELERATE_CONFIG="$SCRIPT_DIR/../accelerate_config/fsdp1.yaml" + ACCELERATE_CONFIG="$SCRIPT_DIR/../configs/accelerate/fsdp2.yaml" fi fi # Add teacher model specific FSDP args if needed if [[ "${HAS_TEACHER_MODEL,,}" == "true" ]]; then - if [[ "${USE_FSDP2,,}" != "true" ]]; then - echo "Error: Quantization aware distillation is only supported with FSDP2." - exit 1 - fi FSDP_ARGS="$FSDP_ARGS --fsdp_cpu_ram_efficient_loading False" fi diff --git a/examples/llm_qat/llama_factory/llama_config.yaml b/examples/llm_qat/llama_factory/llama_config.yaml index 563e3790138..6acbda40a8d 100644 --- a/examples/llm_qat/llama_factory/llama_config.yaml +++ b/examples/llm_qat/llama_factory/llama_config.yaml @@ -45,7 +45,7 @@ eval_strategy: steps eval_steps: 16 modelopt: - quant_cfg: NVFP4_DEFAULT_CFG + recipe: general/ptq/nvfp4_default-kv_fp8 calib_size: 1024 compress: false distill: false diff --git a/examples/llm_qat/llama_factory/llama_factory.py b/examples/llm_qat/llama_factory/llama_factory.py index 121c3b0f394..50fa2d7fc86 100644 --- a/examples/llm_qat/llama_factory/llama_factory.py +++ b/examples/llm_qat/llama_factory/llama_factory.py @@ -43,7 +43,7 @@ @dataclass class QuantizationArguments: - quant_cfg: str | None = None + recipe: str | None = None calib_size: int = 512 compress: bool = False @@ -192,8 +192,8 @@ def create_patch_module(quant_args=None, distill_args=None): into LLaMA-Factory's training pipeline without modifying the original code. Args: - quant_args: SimpleNamespace containing quantization parameters - distill_args: SimpleNamespace containing distillation parameters + quant_args: QuantizationArguments containing quantization parameters + distill_args: DistillationArguments containing distillation parameters Returns: function: Patch function that modifies the trainer class diff --git a/examples/llm_qat/main.py b/examples/llm_qat/main.py deleted file mode 100644 index 2d715881b61..00000000000 --- a/examples/llm_qat/main.py +++ /dev/null @@ -1,272 +0,0 @@ -# Adapted from https://github.com/tatsu-lab/stanford_alpaca/blob/3783d18/train.py - -# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os -from dataclasses import dataclass, field -from warnings import warn - -import torch -import transformers -from transformers.trainer_utils import get_last_checkpoint -from utils import ( - get_lora_config, - get_metrics_with_perplexity, - make_supervised_data_module, - monkey_patch_training_step_to_fix_memory_leak, -) - -import modelopt.torch.opt as mto -import modelopt.torch.quantization as mtq -from modelopt.torch.distill.plugins.huggingface import LMLogitsLoss -from modelopt.torch.quantization.plugins.transformers_trainer import QADTrainer, QATTrainer -from modelopt.torch.utils import print_rank_0 - -# Enable automatic save/load of modelopt state huggingface checkpointing -mto.enable_huggingface_checkpointing() - -CUSTOM_QUANT_CFG = { - "INT4_WEIGHT_INT8_ACTIVATIONS": { - "quant_cfg": [ - {"quantizer_name": "*", "enable": False}, - { - "quantizer_name": "*weight_quantizer", - "cfg": {"num_bits": 4, "block_sizes": {-1: 128}}, - "enable": True, - }, - { - "quantizer_name": "*input_quantizer", - "cfg": {"num_bits": 8, "axis": None}, - "enable": True, - }, - {"quantizer_name": "*lm_head*", "enable": False}, - ], - "algorithm": "max", - } -} - - -@dataclass -class ModelArguments: - model_name_or_path: str = field(default="meta-llama/Llama-2-7b-hf") - teacher_model: str | None = field( - default=None, - metadata={"help": ("The name or path of the teacher model to use for distillation.")}, - ) - - -@dataclass -class TrainingArguments(transformers.TrainingArguments): - cache_dir: str | None = field(default=None) - model_max_length: int = field( - default=2048, - metadata={ - "help": ( - "Maximum sequence length. Sequences will be right padded (and possibly truncated)." - ) - }, - ) - dataloader_drop_last: bool = field(default=True) - bf16: bool = field(default=True) - lora: bool = field( - default=False, - metadata={ - "help": ( - "Whether to add LoRA (Low-Rank Adaptation) adapter before training. When using real quantization, " - "the LoRA adapter must be set, as quantized weights will be frozen during training." - ) - }, - ) - distill: bool = field( - default=False, - metadata={"help": "Select if training with distillation."}, - ) - - -@dataclass -class DataArguments: - dataset: str = field( - default="Daring-Anteater", - metadata={"help": "Specify the dataset.", "choices": ["Daring-Anteater"]}, - ) - train_size: int = field( - default=0, - metadata={"help": "Number of training samples to use. If `0`, use default training size."}, - ) - eval_size: int = field( - default=0, - metadata={ - "help": "Number of evaluation samples to use. If `0`, use default evaluation size." - }, - ) - - -@dataclass -class QuantizationArguments: - quant_cfg: str | None = field( - default=None, - metadata={ - "help": ( - "Specify the quantization format for PTQ/QAT. if specified, PTQ/QAT will be enabled" - " with the specified quantization format" - ), - "choices": mtq.config.choices | CUSTOM_QUANT_CFG.keys(), - }, - ) - calib_size: int = field( - default=512, - metadata={ - "help": ( - "Specify the calibration size for quantization. The calibration dataset is used to" - " setup the quantization scale parameters for PTQ/QAT." - ) - }, - ) - compress: bool = field( - default=False, - metadata={ - "help": ( - "Whether to compress the model weights after quantization. " - "This is useful for reducing the model size." - ) - }, - ) - - -def train(): - parser = transformers.HfArgumentParser( - (ModelArguments, TrainingArguments, DataArguments, QuantizationArguments) - ) - model_args, training_args, data_args, quant_args = parser.parse_args_into_dataclasses() - print_rank_0(f"arguments: {model_args}, {training_args}, {data_args}, {quant_args}") - - # Detecting last checkpoint. - last_checkpoint = None - if os.path.isdir(training_args.output_dir) and training_args.do_train: - last_checkpoint = get_last_checkpoint(training_args.output_dir) - print_rank_0(f"Last checkpoint detected: {last_checkpoint}") - - model = transformers.AutoModelForCausalLM.from_pretrained( - model_args.model_name_or_path, cache_dir=training_args.cache_dir, dtype=torch.bfloat16 - ) - model.generation_config.do_sample = True - tokenizer = transformers.AutoTokenizer.from_pretrained( - model_args.model_name_or_path, model_max_length=training_args.model_max_length - ) - tokenizer.pad_token_id = tokenizer.eos_token_id - - # We set model.config.use_cache to False for training when gradient_checkpointing=False. - # Currently useful for FSDP2 to allow for setting activation_checkpointing=True in the config file. - model.config.use_cache = False - - print_rank_0("Loading dataset...") - data_module = make_supervised_data_module( - dataset=data_args.dataset, - tokenizer=tokenizer, - train_size=data_args.train_size, - eval_size=data_args.eval_size, - ) - - # Ensure calibration size doesn't exceed evaluation dataset size - eval_dataset_size = len(data_module["eval_dataset"]) - if quant_args.calib_size > eval_dataset_size: - warn( - f"{quant_args.calib_size=} is larger than {eval_dataset_size=}. Setting calib_size to {eval_dataset_size}." - ) - quant_args.calib_size = eval_dataset_size - - # Training - checkpoint = None - if training_args.resume_from_checkpoint is not None: - checkpoint = training_args.resume_from_checkpoint - elif last_checkpoint is not None: - checkpoint = last_checkpoint - - if checkpoint is not None and training_args.lora: - raise RuntimeError("Does not support LoRA resuming training yet!") - - # Torch >= 2.4 throws an error if `use_reentrant` is not set explicitly - if training_args.gradient_checkpointing and training_args.gradient_checkpointing_kwargs is None: - training_args.gradient_checkpointing_kwargs = {"use_reentrant": True} - - if quant_args.quant_cfg is not None: - quant_args.quant_cfg = ( - CUSTOM_QUANT_CFG[quant_args.quant_cfg] - if quant_args.quant_cfg in CUSTOM_QUANT_CFG - else getattr(mtq, quant_args.quant_cfg) - ) - distill_kwargs = {} - if training_args.distill: - assert model_args.teacher_model is not None, "Teacher model is required for distillation." - - teacher_model = transformers.AutoModelForCausalLM.from_pretrained( - model_args.teacher_model, - cache_dir=training_args.cache_dir, - dtype=torch.bfloat16, - ) - distill_config = { - "teacher_model": teacher_model, - "criterion": LMLogitsLoss(), - } - distill_kwargs["distill_config"] = distill_config - trainer_cls = QADTrainer if training_args.distill else QATTrainer - - if training_args.lora: - training_args.lora_config = get_lora_config() - - trainer = trainer_cls( - model=model, - processing_class=tokenizer, - args=training_args, - quant_args=quant_args, - **distill_kwargs, - **data_module, - ) - - # There could be GPU memory leak during QAT causing OOM. This is a workaround to fix it. - monkey_patch_training_step_to_fix_memory_leak(trainer) - - if training_args.do_train: - trainer.train(resume_from_checkpoint=checkpoint) - print_rank_0("Training completed.") - - if training_args.do_eval: - metrics = trainer.evaluate() - metrics = get_metrics_with_perplexity(metrics) - print_rank_0(f"Evaluation results: \n{metrics}") - - if training_args.do_train or quant_args.quant_cfg is not None: - print_rank_0("Saving the model...") - trainer.save_state() - trainer.save_model(training_args.output_dir) - - -if __name__ == "__main__": - train() diff --git a/examples/llm_qat/notebooks/QAT_QAD_Walkthrough.ipynb b/examples/llm_qat/notebooks/QAT_QAD_Walkthrough.ipynb index f52d596f7c3..900b3c81c10 100644 --- a/examples/llm_qat/notebooks/QAT_QAD_Walkthrough.ipynb +++ b/examples/llm_qat/notebooks/QAT_QAD_Walkthrough.ipynb @@ -149,23 +149,7 @@ "id": "b6af94af-1de6-4cb1-959b-98fb3f4e1932", "metadata": {}, "outputs": [], - "source": [ - "from trl import ModelConfig\n", - "\n", - "model_args = ModelConfig(\n", - " model_name_or_path=model_name,\n", - " attn_implementation=\"eager\",\n", - " torch_dtype=\"bfloat16\",\n", - ")\n", - "model_kwargs = {\n", - " \"revision\": model_args.model_revision,\n", - " \"trust_remote_code\": model_args.trust_remote_code,\n", - " \"attn_implementation\": model_args.attn_implementation,\n", - " \"torch_dtype\": model_args.torch_dtype,\n", - " \"use_cache\": False,\n", - " \"device_map\": \"auto\",\n", - "}" - ] + "source": "from trl import ModelConfig\n\nmodel_args = ModelConfig(\n model_name_or_path=model_name,\n attn_implementation=\"eager\",\n torch_dtype=\"bfloat16\",\n)\nmodel_kwargs = {\n \"revision\": model_args.model_revision,\n \"trust_remote_code\": model_args.trust_remote_code,\n \"attn_implementation\": model_args.attn_implementation,\n \"dtype\": model_args.torch_dtype,\n \"use_cache\": False,\n \"device_map\": \"auto\",\n}" }, { "cell_type": "markdown", @@ -560,11 +544,7 @@ "cell_type": "markdown", "id": "10acc50c-c876-41d5-8f7e-00dab8842ccd", "metadata": {}, - "source": [ - "**Note:** The QAT checkpoint for `nvfp4` config can be created by using `--quant_cfg NVFP4_DEFAULT_CFG` in QAT example.\n", - "\n", - "See more details on deployment of quantized model [here](https://github.com/NVIDIA/Model-Optimizer/blob/main/examples/llm_ptq/README.md)." - ] + "source": "**Note:** The QAT checkpoint for `nvfp4` config can also be created using the CLI scripts. See the [QAT README](../README.md) for the full end-to-end workflow using `quantize.py`, `train.py`, and `export.py`.\n\nSee more details on deployment of quantized model [here](https://github.com/NVIDIA/Model-Optimizer/blob/main/examples/llm_ptq/README.md)." }, { "cell_type": "markdown", @@ -850,4 +830,4 @@ }, "nbformat": 4, "nbformat_minor": 5 -} +} \ No newline at end of file diff --git a/examples/llm_qat/quantize.py b/examples/llm_qat/quantize.py new file mode 100644 index 00000000000..66b83a5a7e4 --- /dev/null +++ b/examples/llm_qat/quantize.py @@ -0,0 +1,106 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Standalone quantization script for LLMs using ModelOpt recipes. + +This script applies post-training quantization (PTQ) to a model and saves the +quantized checkpoint. The quantized model can then be used for QAT/QAD training +with train.py or exported with export.py. + +Usage: + python quantize.py \ + --model_name_or_path meta-llama/Meta-Llama-3-8B \ + --dataset_config configs/dataset/blend.yaml \ + --recipe general/ptq/nvfp4_default-kv_fp8 \ + --output_dir llama3-quantized +""" + +import os + +import torch +import transformers +from arguments import get_quantize_args +from utils import make_supervised_data_module + +import modelopt.torch.opt as mto +import modelopt.torch.quantization as mtq +from modelopt.torch.quantization.plugins.transformers_trainer import resolve_quant_cfg_from_args +from modelopt.torch.utils import print_rank_0 + +# Enable automatic save/load of modelopt state with huggingface checkpointing +mto.enable_huggingface_checkpointing() + + +def _build_calib_dataloader(tokenizer, data_args, quant_args): + """Build a calibration dataloader from the train dataset.""" + print_rank_0("Loading calibration dataset...") + data_module = make_supervised_data_module(data_args, tokenizer) + train_dataset = data_module["train_dataset"] + num_samples = min(quant_args.calib_size, len(train_dataset)) + calib_dataset = torch.utils.data.Subset(train_dataset, list(range(num_samples))) + return torch.utils.data.DataLoader( + calib_dataset, + batch_size=quant_args.calib_batch_size, + collate_fn=data_module["data_collator"], + ) + + +def quantize(): + model_args, data_args, quant_args = get_quantize_args() + + if quant_args.recipe: + print_rank_0(f"Loading quantization recipe: {quant_args.recipe}") + ptq_cfg = resolve_quant_cfg_from_args(quant_args) + if ptq_cfg is None: + raise ValueError("--recipe or --quant_cfg is required for quantization.") + + # Load model and tokenizer + print_rank_0(f"Loading model: {model_args.model_name_or_path}") + model = transformers.AutoModelForCausalLM.from_pretrained( + model_args.model_name_or_path, + dtype=torch.bfloat16, + device_map="auto", + ) + tokenizer = transformers.AutoTokenizer.from_pretrained( + model_args.model_name_or_path, model_max_length=model_args.model_max_length + ) + if tokenizer.pad_token_id is None: + tokenizer.pad_token_id = tokenizer.eos_token_id + + calib_dataloader = _build_calib_dataloader(tokenizer, data_args, quant_args) + + def forward_loop(model): + for batch in calib_dataloader: + batch = {k: v.to(model.device) for k, v in batch.items()} + model(**batch) + + # Quantize + print_rank_0("Quantizing the model...") + mtq.quantize(model, ptq_cfg, forward_loop) + mtq.print_quant_summary(model) + + if quant_args.compress: + print_rank_0("Compressing model weights for QLoRA...") + mtq.compress(model) + + # Save quantized checkpoint + os.makedirs(quant_args.output_dir, exist_ok=True) + print_rank_0(f"Saving quantized model to {quant_args.output_dir}") + model.save_pretrained(quant_args.output_dir) + tokenizer.save_pretrained(quant_args.output_dir) + + +if __name__ == "__main__": + quantize() diff --git a/examples/llm_qat/requirements.txt b/examples/llm_qat/requirements.txt index b8da4e088f5..d7757e37c7e 100644 --- a/examples/llm_qat/requirements.txt +++ b/examples/llm_qat/requirements.txt @@ -1,3 +1,3 @@ flash-attn py7zr -tensorboardX +tensorboard diff --git a/examples/llm_qat/simple_qat_train.py b/examples/llm_qat/simple_qat_train.py index 8531027845f..e3c3231494c 100644 --- a/examples/llm_qat/simple_qat_train.py +++ b/examples/llm_qat/simple_qat_train.py @@ -18,23 +18,28 @@ import torch import torch.nn as nn +from dataset_utils import build_blend_dataset, load_blend_config from torch.optim import AdamW from torch.utils.data import DataLoader from tqdm import tqdm from transformers import AutoModelForCausalLM, AutoTokenizer -from utils import get_daring_anteater import modelopt.torch.opt as mto import modelopt.torch.quantization as mtq +from modelopt.recipe import ModelOptPTQRecipe, load_recipe def get_dataloader(args, tokenizer): - train_dataset = get_daring_anteater( - tokenizer, "train", args.max_length, args.train_size, args.calib_size - ) - calib_dataset = get_daring_anteater( - tokenizer, "test", args.max_length, args.train_size, args.calib_size - ) + config = load_blend_config("configs/dataset/blend.yaml") + ds = build_blend_dataset(config, tokenizer, args.max_length) + + train_dataset = ds["train"] + if 0 < args.train_size < len(train_dataset): + train_dataset = train_dataset.select(range(args.train_size)) + + calib_dataset = ds["eval"] + if 0 < args.calib_size < len(calib_dataset): + calib_dataset = calib_dataset.select(range(args.calib_size)) def collate_fn(batch): return { @@ -59,7 +64,9 @@ def train(model, optimizer, train_dataloader, tokenizer, epochs, output_dir, dev inputs = batch["input_ids"].to(device) attention_mask = batch["attention_mask"].to(device) - outputs = model(input_ids=inputs, attention_mask=attention_mask, labels=inputs) + outputs = model( + input_ids=inputs, attention_mask=attention_mask, labels=batch["labels"].to(device) + ) loss = outputs.loss optimizer.zero_grad() @@ -85,11 +92,10 @@ def parse_args() -> argparse.Namespace: parser.add_argument("--epochs", type=int, default=2, help="Number of epochs") parser.add_argument("--lr", type=float, default=1e-5, help="Learning rate") parser.add_argument( - "--quant-cfg", + "--recipe", type=str, - default="NVFP4_DEFAULT_CFG", - choices=mtq.config.choices, - help="Quantization configuration", + default="general/ptq/nvfp4_default-kv_fp8", + help="Path to a quantization recipe YAML (built-in or custom)", ) # Reproducibility parser.add_argument("--seed", type=int, default=42, help="Random seed") @@ -118,10 +124,16 @@ def main() -> None: # Calibrate the model def calibrate(m: nn.Module): for batch in calib_dataloader: - m(batch["input_ids"].to(device)) - - # Quantize the model - model = mtq.quantize(model, getattr(mtq, args.quant_cfg), calibrate) + m( + input_ids=batch["input_ids"].to(device), + attention_mask=batch["attention_mask"].to(device), + ) + + # Load recipe and quantize the model + recipe = load_recipe(args.recipe) + if not isinstance(recipe, ModelOptPTQRecipe): + raise ValueError(f"Expected PTQ recipe, but got {type(recipe).__name__} from {args.recipe}") + model = mtq.quantize(model, recipe.quantize, calibrate) # Initialize optimizer optimizer = AdamW(model.parameters(), lr=args.lr) diff --git a/examples/llm_qat/train.py b/examples/llm_qat/train.py new file mode 100644 index 00000000000..7babf05adef --- /dev/null +++ b/examples/llm_qat/train.py @@ -0,0 +1,153 @@ +# Adapted from https://github.com/tatsu-lab/stanford_alpaca/blob/3783d18/train.py + +# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""QAT/QAD training script for pre-quantized LLMs. + +The model should be pre-quantized using quantize.py before running this script. + +Usage: + accelerate launch --config-file configs/accelerate/fsdp2.yaml train.py \ + --config configs/train/qat_nvfp4.yaml +""" + +import os + +os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True") + +from warnings import warn + +import torch +import transformers +from arguments import get_training_args +from transformers.trainer_utils import get_last_checkpoint +from utils import get_lora_config, get_metrics_with_perplexity, make_supervised_data_module + +import modelopt.torch.opt as mto +from modelopt.torch.quantization.plugins.transformers_trainer import QADTrainer, QATTrainer +from modelopt.torch.utils import print_rank_0 + +# Enable automatic save/load of modelopt state huggingface checkpointing +mto.enable_huggingface_checkpointing() + + +def train(): + model_args, training_args, data_args, distill_args = get_training_args() + + if distill_args.distill and getattr(training_args, "fsdp_config", None): + fsdp_cfg = training_args.fsdp_config + if fsdp_cfg.get("fsdp_cpu_ram_efficient_loading", True): + warn( + "Distillation with FSDP2 may require --fsdp_cpu_ram_efficient_loading False. " + "Set this if you encounter issues loading the teacher model." + ) + + print_rank_0(f"arguments: {model_args}, {training_args}, {data_args}, {distill_args}") + + # Detecting last checkpoint. + last_checkpoint = None + if os.path.isdir(training_args.output_dir) and training_args.do_train: + last_checkpoint = get_last_checkpoint(training_args.output_dir) + print_rank_0(f"Last checkpoint detected: {last_checkpoint}") + + model = transformers.AutoModelForCausalLM.from_pretrained( + model_args.model_name_or_path, + cache_dir=training_args.cache_dir, + dtype=torch.bfloat16, + ) + model.generation_config.do_sample = True + tokenizer = transformers.AutoTokenizer.from_pretrained( + model_args.model_name_or_path, model_max_length=model_args.model_max_length + ) + if tokenizer.pad_token_id is None: + tokenizer.pad_token_id = tokenizer.eos_token_id + + # We set model.config.use_cache to False for training when gradient_checkpointing=False. + # Currently useful for FSDP2 to allow for setting activation_checkpointing=True in the config file. + model.config.use_cache = False + + print_rank_0("Loading dataset...") + data_module = make_supervised_data_module(data_args, tokenizer) + + # Training + checkpoint = None + if training_args.resume_from_checkpoint is not None: + checkpoint = training_args.resume_from_checkpoint + elif last_checkpoint is not None: + checkpoint = last_checkpoint + + if checkpoint is not None and training_args.lora: + raise RuntimeError("Does not support LoRA resuming training yet!") + + # Torch >= 2.4 throws an error if `use_reentrant` is not set explicitly + if training_args.gradient_checkpointing and training_args.gradient_checkpointing_kwargs is None: + training_args.gradient_checkpointing_kwargs = {"use_reentrant": True} + + distill_kwargs = {} + if distill_args.distill: + if distill_args.teacher_model is None: + raise ValueError("--teacher_model is required when --distill is enabled.") + + teacher_model = transformers.AutoModelForCausalLM.from_pretrained( + distill_args.teacher_model, + cache_dir=training_args.cache_dir, + dtype=torch.bfloat16, + ) + distill_kwargs = distill_args.to_distill_kwargs(teacher_model) + trainer_cls = QADTrainer if distill_args.distill else QATTrainer + + if training_args.lora: + training_args.lora_config = get_lora_config() + + trainer = trainer_cls( + model=model, + processing_class=tokenizer, + args=training_args, + **distill_kwargs, + **data_module, + ) + + if training_args.do_train: + trainer.train(resume_from_checkpoint=checkpoint) + print_rank_0("Training completed.") + + if training_args.do_eval: + metrics = trainer.evaluate() + metrics = get_metrics_with_perplexity(metrics) + print_rank_0(f"Evaluation results: \n{metrics}") + + if training_args.do_train: + print_rank_0("Saving the model...") + trainer.save_state() + trainer.save_model(training_args.output_dir) + + +if __name__ == "__main__": + train() diff --git a/examples/llm_qat/utils.py b/examples/llm_qat/utils.py index bb70bdf1221..19e874e55fc 100644 --- a/examples/llm_qat/utils.py +++ b/examples/llm_qat/utils.py @@ -13,130 +13,44 @@ # See the License for the specific language governing permissions and # limitations under the License. -import gc -import types -from contextlib import contextmanager -from functools import partial - -import datasets import torch import transformers from peft import LoraConfig, TaskType from transformers import default_data_collator -IGNORE_INDEX = -100 - - -@contextmanager -def main_process_first(): - """Context manager to run code on the main process first.""" - if not torch.distributed.is_initialized(): - yield - return - - rank = torch.distributed.get_rank() - if rank == 0: - yield - torch.distributed.barrier() - else: - torch.distributed.barrier() - yield - torch.distributed.barrier() - - -def get_daring_anteater( - tokenizer: transformers.AutoTokenizer, - split="train", - max_length=4096, - train_size=0, - eval_size=0, -): - # sample = { - # 'system': '{system message}', - # 'conversations': [ - # {'from': 'User', 'value': '{turn 1 user message}', 'label': None}, - # {'from': 'Assistant', 'value': '{turn 1 assistant message}', 'label': '{turn 1 assistant label}'}, - # {'from': 'User', 'value': '{turn 2 user message}', 'label': None}, - # {'from': 'Assistant', 'value': '{turn 2 assistant message}', 'label': '{turn 2 assistant label}'}, - # ], - # "mask": "User", - # "type": "VALUE_TO_TEXT", - # } - - def process_and_tokenize(sample): - conversations = sample["conversations"] - all_input_ids = [tokenizer.bos_token_id] if tokenizer.bos_token_id else [] - all_labels = [IGNORE_INDEX] if tokenizer.bos_token_id else [] - - for conversation in conversations: - role = conversation["from"] - input_ids = tokenizer.encode(conversation["value"] + "\n", add_special_tokens=False) - labels = input_ids if role == "Assistant" else [IGNORE_INDEX] * len(input_ids) - - all_input_ids.extend(input_ids) - all_labels.extend(labels) - - if len(all_input_ids) > max_length: - break - all_input_ids.append(tokenizer.eos_token_id) - all_labels.append(IGNORE_INDEX) - all_attention_mask = [1] * len(all_input_ids) - - cur_seq_length = len(all_input_ids) - if cur_seq_length < max_length: - pad_token = ( - tokenizer.pad_token_id - if tokenizer.pad_token_id is not None - else tokenizer.eos_token_id - ) - all_input_ids += [pad_token] * (max_length - cur_seq_length) - all_attention_mask += [0] * (max_length - cur_seq_length) - all_labels += [IGNORE_INDEX] * (max_length - cur_seq_length) - - return { - "input_ids": all_input_ids[:max_length], - "attention_mask": all_attention_mask[:max_length], - "labels": all_labels[:max_length], - } +def make_supervised_data_module( + data_args, + tokenizer: transformers.PreTrainedTokenizer, +) -> dict: + """Build train/eval datasets and a default collator.""" + from dataset_utils import build_blend_dataset, load_blend_config + + config = load_blend_config(data_args.dataset_config) + max_length = getattr(tokenizer, "model_max_length", 4096) + + ds = build_blend_dataset( + config, + tokenizer, + max_length, + seed=data_args.dataset_seed, + cache_dir=data_args.dataset_cache_dir, + shuffle=data_args.shuffle, + shuffle_buffer=data_args.shuffle_buffer, + num_proc=data_args.num_proc, + ) - if hasattr(get_daring_anteater, "cached_dataset"): - dataset = get_daring_anteater.cached_dataset - else: - with main_process_first(): - dataset = datasets.load_dataset("nvidia/Daring-Anteater", split="train") - # Shuffle and subsample the dataset - eval_size = 2000 if eval_size == 0 else eval_size - train_size = len(dataset) - eval_size if train_size == 0 else train_size - assert train_size + eval_size <= len(dataset) and train_size > 0 and eval_size > 0, ( - "not enough data for train-eval split" - ) - dataset = dataset.shuffle(seed=42).select(range(train_size + eval_size)) - dataset = dataset.map(process_and_tokenize, remove_columns=list(dataset.features)) - dataset = dataset.train_test_split(test_size=eval_size, shuffle=True, seed=42) - get_daring_anteater.cached_dataset = dataset - return dataset[split] + train_ds = ds["train"] + if data_args.train_samples > 0 and data_args.train_samples < len(train_ds): + train_ds = train_ds.select(range(data_args.train_samples)) + eval_ds = ds["eval"] + if data_args.eval_samples > 0 and data_args.eval_samples < len(eval_ds): + eval_ds = eval_ds.select(range(data_args.eval_samples)) -def make_supervised_data_module( - dataset="Daring-Anteater", - tokenizer: transformers.PreTrainedTokenizer = None, - train_size: int = 0, - eval_size: int = 0, -) -> dict: - """Make dataset and collmtor for supervised fine-tuning.""" - if dataset == "Daring-Anteater": - train_dataset = get_daring_anteater( - tokenizer, "train", tokenizer.model_max_length, train_size, eval_size - ) - val_dataset = get_daring_anteater( - tokenizer, "test", tokenizer.model_max_length, train_size, eval_size - ) - else: - raise ValueError(f"Dataset {dataset} not supported") return { - "train_dataset": train_dataset, - "eval_dataset": val_dataset, + "train_dataset": train_ds, + "eval_dataset": eval_ds, "data_collator": default_data_collator, } @@ -157,18 +71,6 @@ def get_lora_config(): ) -def monkey_patch_training_step_to_fix_memory_leak(trainer): - def new_func(original_f_name, trainer, *args, **kwargs): - gc.collect() - return getattr(trainer, original_f_name)(*args, **kwargs) - - for f_name in ["training_step", "prediction_step", "_load_best_model"]: - setattr(trainer, "_original_" + f_name, getattr(trainer, f_name)) - setattr( - trainer, f_name, types.MethodType(partial(new_func, "_original_" + f_name), trainer) - ) - - def get_metrics_with_perplexity(metrics): """Add perplexity to the metrics.""" if "eval_loss" in metrics: diff --git a/examples/vllm_serve/README.md b/examples/vllm_serve/README.md index 4e23a56a288..d9995c9d7fb 100644 --- a/examples/vllm_serve/README.md +++ b/examples/vllm_serve/README.md @@ -79,7 +79,7 @@ python ../llm_ptq/hf_ptq.py \ This creates `/vllm_fq_modelopt_state.pth` (ModelOpt quantizer state for vLLM fake-quant reload) and saves the HF-exported model under `` (config/tokenizer/weights). - Note: `--pyt_ckpt_path` can point to either an HF checkpoint or a ModelOpt-saved checkpoint (e.g., a QAT/QAD checkpoint produced by `examples/llm_qat/main.py`). If the input checkpoint is already quantized, the script will **skip re-quantization** and only export artifacts for vLLM fakequant reload. + Note: `--pyt_ckpt_path` can point to either an HF checkpoint or a ModelOpt-saved checkpoint (e.g., a QAT/QAD checkpoint produced by `examples/llm_qat/train.py`). If the input checkpoint is already quantized, the script will **skip re-quantization** and only export artifacts for vLLM fakequant reload. - For **MCore** models, export the model with flag `--export-vllm-fq` as described in [Megatron-LM README](https://github.com/NVIDIA/Megatron-LM/tree/main/examples/post_training/modelopt#-nvfp4-quantization-qauntization-aware-training-and-model-export). This generates `quantizer_state.pth`, which contains quantizer tensors for vLLM reload via `QUANT_FILE_PATH`. diff --git a/modelopt/torch/distill/plugins/huggingface.py b/modelopt/torch/distill/plugins/huggingface.py index c865d885717..f4202d10fca 100644 --- a/modelopt/torch/distill/plugins/huggingface.py +++ b/modelopt/torch/distill/plugins/huggingface.py @@ -15,16 +15,55 @@ """ModelOpt plugin to train HuggingFace models with knowledge distillation.""" +from dataclasses import field + from torch import Tensor from transformers.modeling_outputs import CausalLMOutputWithPast from transformers.trainer_pt_utils import LabelSmoother import modelopt.torch.distill as mtd from modelopt.torch.opt.plugins import ModelOptHFTrainer +from modelopt.torch.opt.plugins.transformers import ModelOptHFArguments from modelopt.torch.utils import print_rank_0 IGNORE_TOKEN_ID = LabelSmoother.ignore_index # equals -100 +_SUPPORTED_CRITERIA = {"logits_loss"} + + +class DistillArguments(ModelOptHFArguments): + """Distillation arguments for knowledge distillation training.""" + + distill: bool = field( + default=False, + metadata={"help": "Enable training with knowledge distillation."}, + ) + teacher_model: str | None = field( + default=None, + metadata={"help": "The name or path of the teacher model to use for distillation."}, + ) + criterion: str = field( + default="logits_loss", + metadata={ + "help": "Distillation loss criterion. Currently only 'logits_loss' is supported." + }, + ) + + def to_distill_kwargs(self, teacher_model) -> dict: + """Convert distill args to kwargs for KDTrainer/QADTrainer. + + Args: + teacher_model: The loaded teacher model instance. + + Returns: + Dict with ``distill_config`` ready to pass to the trainer. + """ + if self.criterion not in _SUPPORTED_CRITERIA: + raise ValueError( + f"Unsupported criterion: {self.criterion!r}. Supported: {_SUPPORTED_CRITERIA}" + ) + return {"distill_config": {"teacher_model": teacher_model, "criterion": LMLogitsLoss()}} + class KDTrainer(ModelOptHFTrainer): """Distillation trainer for HuggingFace models.""" diff --git a/modelopt/torch/opt/plugins/transformers.py b/modelopt/torch/opt/plugins/transformers.py index 3370309156d..f94fd6dd1b1 100644 --- a/modelopt/torch/opt/plugins/transformers.py +++ b/modelopt/torch/opt/plugins/transformers.py @@ -15,14 +15,17 @@ """ModelOpt plugin for enabling automatic save/restore of ModelOpt state for HuggingFace models.""" +import dataclasses import os +import sys import types from contextlib import contextmanager +from pathlib import Path import torch import transformers from packaging.version import Version -from transformers import PreTrainedModel, Trainer, TrainerCallback +from transformers import HfArgumentParser, PreTrainedModel, Trainer, TrainerCallback from transformers import modeling_utils as tf_modeling_utils from modelopt.torch.utils import report_memory @@ -36,7 +39,7 @@ register_for_patching, ) -__all__ = ["ModelOptHFTrainer"] +__all__ = ["ModelOptArgParser", "ModelOptHFArguments", "ModelOptHFTrainer"] @contextmanager @@ -162,6 +165,175 @@ def _load_params_and_buffers_into_zero3_model(model_to_load, state_dict, load_co ) +@dataclasses.dataclass +class ModelOptHFArguments: + """Base for all ModelOpt argument dataclasses used with :class:`ModelOptArgParser`. + + Subclasses are automatically treated as dataclasses (no ``@dataclass`` decorator needed). + """ + + def __init_subclass__(cls, **kwargs): + super().__init_subclass__(**kwargs) + dataclasses.dataclass(cls) + + +class ModelOptArgParser(HfArgumentParser): + """HfArgumentParser with ``--config`` YAML support and ``--generate_docs`` for ARGUMENTS.md.""" + + def parse_args_into_dataclasses(self, args=None, **kwargs): + """Parse args with optional YAML config defaults and doc generation.""" + if args is None: + args = list(sys.argv[1:]) + + # --generate_docs [output_path]: generate markdown and exit + if "--generate_docs" in args: + idx = args.index("--generate_docs") + output = ( + args[idx + 1] + if idx + 1 < len(args) and not args[idx + 1].startswith("--") + else "ARGUMENTS.md" + ) + self._generate_docs(output) + sys.exit(0) + + # --config : load YAML as defaults, CLI args override + if "--config" in args: + idx = args.index("--config") + if idx + 1 >= len(args): + raise ValueError("--config requires a path argument") + config_path = args[idx + 1] + args = args[:idx] + args[idx + 2 :] # strip --config from argv + import yaml + + with open(config_path) as f: + config = yaml.safe_load(f) + if config: + known_by_parser = {a.dest for a in self._actions} + all_modelopt_fields = self._all_modelopt_fields() + applicable = {} + for k, v in config.items(): + if k in known_by_parser: + applicable[k] = v + elif k not in all_modelopt_fields: + raise ValueError( + f"Unknown config key '{k}' in {config_path}. " + f"Not recognized by any ModelOptHFArguments subclass." + ) + self.set_defaults(**applicable) + + return super().parse_args_into_dataclasses(args=args, **kwargs) + + @staticmethod + def _all_modelopt_fields(): + """Collect all field names from every ModelOptHFArguments subclass.""" + fields = set() + queue = list(ModelOptHFArguments.__subclasses__()) + while queue: + cls = queue.pop() + if dataclasses.is_dataclass(cls): + fields.update(f.name for f in dataclasses.fields(cls)) + queue.extend(cls.__subclasses__()) + return fields + + def _generate_docs(self, output_path: str) -> None: + """Generate a markdown argument reference from registered dataclass types.""" + regen_cmd = f"python {sys.argv[0]} --generate_docs {output_path}" + lines = [ + "# Argument Reference", + "", + f"", + "", + ] + + # Sort: modelopt library classes first, then example-specific + def _sort_key(dc): + mod = dc.__module__ or "" + return (0 if mod.startswith("modelopt.") else 1, mod, dc.__name__) + + sorted_types = sorted(self.dataclass_types, key=_sort_key) + + # Fields belonging to HF TrainingArguments (used to detect "own" fields) + hf_training_fields: set[str] = set() + if hasattr(transformers, "TrainingArguments"): + hf_training_fields = { + f.name for f in dataclasses.fields(transformers.TrainingArguments) + } + + for dc in sorted_types: + group_name = dc.__name__ + lines.append(f"## {group_name}") + lines.append("") + + is_hf_subclass = ( + hasattr(transformers, "TrainingArguments") + and issubclass(dc, transformers.TrainingArguments) + and dc is not transformers.TrainingArguments + ) + + if is_hf_subclass: + own_fields = [f for f in dataclasses.fields(dc) if f.name not in hf_training_fields] + lines.append( + "Extends [HuggingFace TrainingArguments]" + "(https://huggingface.co/docs/transformers/main_classes/trainer#transformers.TrainingArguments)." + " Only additional arguments are shown below." + ) + lines.append("") + else: + own_fields = list(dataclasses.fields(dc)) + + if not own_fields: + lines.append("_No additional arguments._") + lines.append("") + continue + + lines.append("| Argument | Type | Default | Description |") + lines.append("|----------|------|---------|-------------|") + + for f in own_fields: + name = f"`--{f.name}`" + type_str = self._format_type(f.type) + default_str = self._format_default(f.default, f.default_factory) + help_text = dict(f.metadata).get("help", "") + # Collapse multi-line help into single line + help_text = " ".join(help_text.split()) + lines.append(f"| {name} | {type_str} | {default_str} | {help_text} |") + + lines.append("") + + # Remove trailing blank lines so markdownlint won't modify the file + while lines and lines[-1] == "": + lines.pop() + Path(output_path).write_text("\n".join(lines) + "\n") + print(f"Generated {output_path}") + + @staticmethod + def _format_type(type_hint) -> str: + """Format a type hint for display in markdown.""" + s = str(type_hint) + # Clean up common type representations + for old, new in [ + ("typing.", ""), + ("typing_extensions.", ""), + ("", ""), + ]: + s = s.replace(old, new) + return f"`{s}`" + + @staticmethod + def _format_default(default, default_factory) -> str: + """Format a default value for display in markdown.""" + if default is not dataclasses.MISSING: + if default is None: + return "`None`" + if isinstance(default, str): + return f'`"{default}"`' + return f"`{default}`" + if default_factory is not dataclasses.MISSING: + return "_factory_" + return "_required_" + + def _report_memory(msg): if not torch.cuda.is_available(): return diff --git a/modelopt/torch/quantization/plugins/transformers_trainer.py b/modelopt/torch/quantization/plugins/transformers_trainer.py index 25363278439..202def592f7 100644 --- a/modelopt/torch/quantization/plugins/transformers_trainer.py +++ b/modelopt/torch/quantization/plugins/transformers_trainer.py @@ -20,7 +20,8 @@ import json import os import types -from dataclasses import dataclass, field +import warnings +from dataclasses import field import torch from tqdm import tqdm @@ -29,6 +30,7 @@ import modelopt.torch.quantization as mtq from modelopt.torch.distill.plugins.huggingface import KDTrainer from modelopt.torch.opt.plugins import ModelOptHFTrainer +from modelopt.torch.opt.plugins.transformers import ModelOptHFArguments from modelopt.torch.utils import get_module_device, print_rank_0 from ..config import QuantizeConfig @@ -44,21 +46,25 @@ # TODO: Enable documentation rendering for this class -@dataclass -class QuantizationArguments: - """Quantization arguments for quantization aware training. +class QuantizationArguments(ModelOptHFArguments): + """Quantization arguments for ModelOpt Hugging Face trainer integrations.""" - This classes is intended to be used with ModelOpt's QAT/QAD trainers for HuggingFace models. - This class can also be used to parse the quantization arguments - from the command line to the taining script. - """ - - quant_cfg: str | None = field( + recipe: str | None = field( + default=None, + metadata={ + "help": ( + "Path to a quantization recipe YAML file (built-in or custom). " + "Built-in recipes can be specified by relative path, e.g. " + "'general/ptq/nvfp4_default-kv_fp8'. Replaces the deprecated --quant_cfg flag." + ), + }, + ) + quant_cfg: str | QuantizeConfig | None = field( default=None, metadata={ "help": ( - "Specify the quantization format for PTQ/QAT. if specified, PTQ/QAT will be enabled" - " with the specified quantization format" + "Deprecated: pre-quantize the model with a separate quantization step instead. " + "Specify the quantization format for PTQ/QAT by name (e.g. NVFP4_DEFAULT_CFG)." ), }, ) @@ -82,22 +88,37 @@ class QuantizationArguments: ) -class QuantizationArgumentsWithConfig(QuantizationArguments): - """Quantization arguments for quantization aware training with config. +def resolve_quant_cfg_from_args( + quant_args: QuantizationArguments | None, + *, + warn_on_quant_cfg: bool = False, +): + """Resolve a ModelOpt quantization config from recipe or legacy quant_cfg arguments.""" + if quant_args is None: + return None - This class is intended to be used with ModelOpt's QAT/QAD trainers for HuggingFace models, - however, it cannot be used for command line parsing. - """ + recipe_path = getattr(quant_args, "recipe", None) + if recipe_path: + from modelopt.recipe import ModelOptPTQRecipe, load_recipe - quant_cfg: str | QuantizeConfig | None = field( - default=None, - metadata={ - "help": ( - "Specify the quantization format for PTQ/QAT. if specified, PTQ/QAT will be enabled" - " with the specified quantization format" - ), - }, - ) + recipe = load_recipe(recipe_path) + if not isinstance(recipe, ModelOptPTQRecipe): + raise ValueError( + f"Expected PTQ recipe, but got {type(recipe).__name__} from {recipe_path}" + ) + return recipe.quantize + + quant_cfg = getattr(quant_args, "quant_cfg", None) + if quant_cfg is None: + return None + if warn_on_quant_cfg: + warnings.warn( + "In-trainer quantization via quant_args is deprecated and will be removed in a " + "future release. Pre-quantize your model with a separate quantization step instead.", + DeprecationWarning, + stacklevel=3, + ) + return getattr(mtq, quant_cfg) if isinstance(quant_cfg, str) else quant_cfg def _patch_fsdp2_post_backward(): @@ -164,28 +185,21 @@ class QATTrainer(ModelOptHFTrainer): """A drop-in replacement of HuggingFace's Trainer for quantization aware training with ModelOpt. This class takes an additional optional argument `quant_args` of type - :class:`QuantizationArgumentsWithConfig ` + :class:`QuantizationArguments ` to specify the quantization arguments. """ def __init__( self, *args, - quant_args: QuantizationArgumentsWithConfig | QuantizationArguments | None = None, + quant_args: QuantizationArguments | None = None, **kwargs, ): """Initialize the trainer with modelopt states.""" super().__init__(*args, **kwargs) self.quant_args = quant_args - quant_cfg = None - if quant_args is not None and getattr(quant_args, "quant_cfg", None): - quant_cfg = ( - getattr(mtq, quant_args.quant_cfg) - if isinstance(quant_args.quant_cfg, str) - else quant_args.quant_cfg - ) - self.quant_cfg = quant_cfg + self.quant_cfg = resolve_quant_cfg_from_args(quant_args, warn_on_quant_cfg=True) # Add lora adapter before quantizing the model if getattr(self.args, "lora_config", None) is not None and not hasattr( @@ -209,7 +223,7 @@ def __init__( self._patch_accelerate_for_fsdp2_fix() self._modelopt_state_path = os.path.join(self.args.output_dir, "modelopt_state_train.pth") - if os.path.exists(self._modelopt_state_path): + if os.path.exists(self._modelopt_state_path) and not is_quantized(self.model): self._restore_modelopt_state_with_weights() elif is_quantized(self.model): self._save_modelopt_state_with_weights() @@ -232,6 +246,7 @@ def _save_modelopt_state_with_weights(self): print_rank_0(f"Saved modelopt state to {self._modelopt_state_path}") def _restore_modelopt_state_with_weights(self): + """Restore the modelopt state with weights.""" modelopt_state = mto.load_modelopt_state(self._modelopt_state_path) modelopt_weights = modelopt_state.pop("modelopt_state_weights", None) mto.restore_from_modelopt_state(self.model, modelopt_state) @@ -257,7 +272,7 @@ def forward_loop(model): # TODO: Remove calibrate_with_adapters - this should not be needed with calibrate_with_adapters(self.model, self.args): print_rank_0("Quantizing the model...") - mtq.quantize(self.model, self.quant_cfg, forward_loop) # type: ignore [arg-type] + mtq.quantize(self.model, self.quant_cfg, forward_loop) # Save modelopt state self._save_modelopt_state_with_weights() @@ -276,12 +291,14 @@ def forward_loop(model): def training_step(self, *args, **kwargs): """Training step.""" + gc.collect() if self.quant_cfg is not None and not is_quantized(self.model): self._quantize_model() return super().training_step(*args, **kwargs) def prediction_step(self, *args, **kwargs): """Prediction step.""" + gc.collect() if self.quant_cfg is not None and not is_quantized(self.model): self._quantize_model() return super().prediction_step(*args, **kwargs) @@ -332,6 +349,7 @@ def save_model(self, *args, **kwargs): def _load_best_model(self, *args, **kwargs): """Load the best model for final evaluation.""" + gc.collect() is_lora = getattr(self.args, "lora", None) if is_lora and not self.is_fsdp_enabled: # Custom logic for loading best model with LoRA @@ -378,12 +396,7 @@ def _update_config_json_dtype(self, output_dir: str, dtype_str: str | None) -> N print_rank_0(f"[warn] Failed to update dtype in config.json: {e}") def _patch_accelerate_for_fsdp2_fix(self): - """Fixes for accelerate prepare. - - Accelerate fsdp2 prepare assumes that all parameters and buffers are sharded. This assumption - is causing issues with quantized models since quantization modules adds buffers which are not sharded. - This patch hides the buffers added by quantization modules from the original accelerate prepare. - """ + """Patch accelerate FSDP2 prepare for TensorQuantizer buffers.""" _patch_fsdp2_post_backward() def _modelopt_prepare(self, *args, **kwargs): @@ -394,18 +407,39 @@ def _modelopt_prepare(self, *args, **kwargs): if model is None: return self._original_prepare(*args, **kwargs) + # Hide TQ buffers from accelerate's FSDP2 state_dict handling. tq_og_non_prsist_buffers = {} for tq in (m for m in model.modules() if isinstance(m, TensorQuantizer)): - tq.to_empty(device=self.device) + # With fsdp_cpu_ram_efficient_loading=true, non-rank-0 processes + # hold meta-device buffers which cannot be moved with .to(). + # Allocate empty tensors on the target device for those; real + # values are broadcast from rank 0 after _original_prepare below. + for name, buf in list(tq._buffers.items()): + if buf is None: + continue + tq._buffers[name] = ( + torch.empty_like(buf, device=self.device) + if buf.is_meta + else buf.to(self.device) + ) tq_og_non_prsist_buffers[tq] = tq._non_persistent_buffers_set.copy() tq._non_persistent_buffers_set.update(tq._buffers.keys()) outputs = self._original_prepare(*args, **kwargs) + # Restore original buffer persistence. for tq in (m for m in model.modules() if isinstance(m, TensorQuantizer)): tq._non_persistent_buffers_set.clear() tq._non_persistent_buffers_set.update(tq_og_non_prsist_buffers[tq]) + # Sync TQ buffers across ranks. With cpu_ram_efficient_loading, only rank 0 + # has valid buffer values; other ranks have uninitialized meta-device values. + if torch.distributed.is_initialized(): + for tq in (m for m in model.modules() if isinstance(m, TensorQuantizer)): + for buf in tq._buffers.values(): + if buf is not None: + torch.distributed.broadcast(buf, src=0) + return outputs self.accelerator._original_prepare = self.accelerator.prepare diff --git a/modelopt_recipes/general/ptq/int4_blockwise_weight_only.yaml b/modelopt_recipes/general/ptq/int4_blockwise_weight_only.yaml new file mode 100644 index 00000000000..432b970339c --- /dev/null +++ b/modelopt_recipes/general/ptq/int4_blockwise_weight_only.yaml @@ -0,0 +1,62 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +metadata: + recipe_type: ptq + description: INT4 blockwise weight-only (W4A16, block size 128), max calibration. +quantize: + algorithm: max + quant_cfg: + - quantizer_name: '*' + enable: false + - quantizer_name: '*weight_quantizer' + cfg: + num_bits: 4 + block_sizes: + -1: 128 + - quantizer_name: '*input_quantizer' + enable: false + - quantizer_name: '*block_sparse_moe.gate*' + enable: false + - quantizer_name: '*linear_attn.conv1d*' + enable: false + - quantizer_name: '*lm_head*' + enable: false + - quantizer_name: '*mixer.conv1d*' + enable: false + - quantizer_name: '*mlp.gate.*' + enable: false + - quantizer_name: '*mlp.shared_expert_gate.*' + enable: false + - quantizer_name: '*output_layer*' + enable: false + - quantizer_name: '*proj_out.*' + enable: false + - quantizer_name: '*router*' + enable: false + - quantizer_name: 'output.*' + enable: false + - parent_class: 'nn.BatchNorm1d' + quantizer_name: '*' + enable: false + - parent_class: 'nn.BatchNorm2d' + quantizer_name: '*' + enable: false + - parent_class: 'nn.BatchNorm3d' + quantizer_name: '*' + enable: false + - parent_class: 'nn.LeakyReLU' + quantizer_name: '*' + enable: false diff --git a/tests/examples/conftest.py b/tests/examples/conftest.py index 1087cefa2c9..b148e5500ea 100644 --- a/tests/examples/conftest.py +++ b/tests/examples/conftest.py @@ -15,7 +15,11 @@ import pytest -from _test_utils.torch.transformers_models import create_tiny_llama_dir +from _test_utils.torch.transformers_models import ( + create_tiny_gpt_oss_dir, + create_tiny_llama_dir, + create_tiny_qwen3_dir, +) @pytest.fixture(scope="session") @@ -28,3 +32,26 @@ def tiny_llama_path(tmp_path_factory): intermediate_size=512, ) ) + + +@pytest.fixture(scope="session") +def tiny_qwen3_path(tmp_path_factory): + return str( + create_tiny_qwen3_dir( + tmp_path_factory.mktemp("tiny_qwen3"), + with_tokenizer=True, + hidden_size=512, + intermediate_size=512, + ) + ) + + +@pytest.fixture(scope="session") +def tiny_gpt_oss_path(tmp_path_factory): + return str( + create_tiny_gpt_oss_dir( + tmp_path_factory.mktemp("tiny_gpt_oss"), + with_tokenizer=True, + num_hidden_layers=2, + ) + ) diff --git a/tests/examples/gpt_oss/test_gpt_oss_qat.py b/tests/examples/gpt_oss/test_gpt_oss_qat.py index 43464110b25..f36dd9e023e 100644 --- a/tests/examples/gpt_oss/test_gpt_oss_qat.py +++ b/tests/examples/gpt_oss/test_gpt_oss_qat.py @@ -19,8 +19,7 @@ import pytest from _test_utils.examples.run_command import run_example_command from _test_utils.torch.misc import minimum_gpu - -pytestmark = pytest.mark.release(reason="This test is used for QA release.") +from datasets import Dataset, DatasetDict class GPTOSS: @@ -285,6 +284,7 @@ def deploy_gpt_oss_trtllm(self, tmp_path, model_path_override=None): run_example_command(cmd_parts, "gpt-oss") +@pytest.mark.release(reason="This test is used for QA release.") @pytest.mark.parametrize( "model_path", [ @@ -346,3 +346,70 @@ def test_gpt_oss_complete_pipeline(model_path, tmp_path): print("Step 3: Running deployment with MXFP4 checkpoint...") gpt_oss.deploy_gpt_oss_trtllm(tmp_path, model_path_override=mxfp4_checkpoint) print("Step 3 completed: Deployment successful") + + +def test_gpt_oss_sft_toy(tiny_gpt_oss_path, tmp_path): + """CPU smoke test: SFT-only (no quantization) on a tiny gpt-oss model. + + Validates that sft.py parses the TrlParser arguments including the new + QuantizationArguments.recipe field and completes train + save. + """ + dataset_dir = tmp_path / "dataset" + DatasetDict( + { + "train": Dataset.from_dict( + { + "text": [ + "The quick brown fox.", + "Hello world.", + "A short sentence.", + "Another one.", + ] + } + ) + } + ).save_to_disk(str(dataset_dir)) + + output_dir = tmp_path / "sft-toy" + run_example_command( + [ + "python", + "sft.py", + "--model_name_or_path", + tiny_gpt_oss_path, + "--dataset_name", + str(dataset_dir), + "--output_dir", + str(output_dir), + "--attn_implementation", + "eager", + "--max_length", + "64", + "--num_train_epochs", + "1", + "--max_steps", + "2", + "--per_device_train_batch_size", + "1", + "--per_device_eval_batch_size", + "1", + "--eval_strategy", + "no", + "--logging_steps", + "1", + "--report_to", + "none", + "--dtype", + "float32", + "--bf16", + "False", + "--use_cpu", + "True", + ], + "gpt-oss", + env={**os.environ, "CUDA_VISIBLE_DEVICES": ""}, + ) + + assert output_dir.exists(), "SFT output directory should exist after training" + assert (output_dir / "config.json").exists(), "Saved model config.json should exist" + assert any(output_dir.glob("*.safetensors")), "Saved model safetensors should exist" diff --git a/tests/examples/llm_qat/test_assistant_mask.py b/tests/examples/llm_qat/test_assistant_mask.py new file mode 100644 index 00000000000..bbfb38b19aa --- /dev/null +++ b/tests/examples/llm_qat/test_assistant_mask.py @@ -0,0 +1,61 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit test for ChatML heuristic assistant masking using real Qwen3 tokenizer.""" + +import sys +from pathlib import Path + +import pytest +import transformers + +sys.path.insert(0, str(Path(__file__).resolve().parents[3] / "examples" / "llm_qat")) + +from dataset_utils import _chatml_assistant_mask, _supports_chatml_heuristic + +CONVERSATION = [ + {"role": "user", "content": "Hello assistant"}, + {"role": "assistant", "content": "Hello user"}, + {"role": "user", "content": "How are you?"}, + {"role": "assistant", "content": "I'm good"}, +] + + +@pytest.fixture(scope="module") +def tokenizer(): + return transformers.AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B") + + +def test_chatml_assistant_mask(tokenizer): + assert _supports_chatml_heuristic(tokenizer) + + result = tokenizer.apply_chat_template( + CONVERSATION, + tokenize=True, + return_dict=True, + return_assistant_tokens_mask=True, + ) + input_ids = result["input_ids"] + heuristic = _chatml_assistant_mask(input_ids, tokenizer) + + assert len(heuristic) == len(input_ids) + assert sum(heuristic) > 0, "heuristic should mask some assistant tokens" + + tokens = [tokenizer.decode([tid]) for tid in input_ids] + for tok, m in zip(tokens, heuristic): + if m == 1: + assert tok not in ("<|im_start|>", "user", "system"), ( + f"non-content token {tok!r} should not be masked" + ) diff --git a/tests/examples/llm_qat/test_dataset_tokenization.py b/tests/examples/llm_qat/test_dataset_tokenization.py new file mode 100644 index 00000000000..bcea842e03f --- /dev/null +++ b/tests/examples/llm_qat/test_dataset_tokenization.py @@ -0,0 +1,195 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +from pathlib import Path + +import pytest + +sys.path.insert(0, str(Path(__file__).resolve().parents[3] / "examples" / "llm_qat")) + +from dataset_utils import IGNORE_TOKEN_ID, DatasetSourceConfig, make_chat_tokenize_fn + +CONVERSATION = [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi"}, +] + + +class FakeChatTokenizer: + pad_token_id = 0 + eos_token_id = 2 + unk_token_id = 999 + vocab_size = 1024 + + def __init__( + self, + *, + chat_template: str = "plain template", + chatml: bool = False, + input_ids: list[int] | None = None, + attention_mask: list[int] | None = None, + assistant_masks: list[int] | None = None, + name_or_path: str = "custom-model", + ): + self.chat_template = chat_template + self.chatml = chatml + self.input_ids = input_ids or [11, 22, 33, self.pad_token_id] + self.attention_mask = attention_mask or [1, 1, 1, 0] + self.assistant_masks = assistant_masks or [0, 1, 0, 0] + self.name_or_path = name_or_path + self.return_assistant_tokens_mask_calls: list[bool] = [] + + def convert_tokens_to_ids(self, token: str) -> int: + if not self.chatml: + return self.unk_token_id + return {"<|im_start|>": 100, "<|im_end|>": 101}.get(token, self.unk_token_id) + + def encode(self, text: str, add_special_tokens: bool = False) -> list[int]: + del add_special_tokens + return {"assistant": [10], "\n": [11]}.get(text, [88]) + + def apply_chat_template( + self, + messages, + *, + tokenize: bool, + return_dict: bool, + return_assistant_tokens_mask: bool, + padding: str, + truncation: bool, + max_length: int, + ) -> dict[str, list[int]]: + del messages, tokenize, return_dict, padding, truncation, max_length + self.return_assistant_tokens_mask_calls.append(return_assistant_tokens_mask) + result = { + "input_ids": list(self.input_ids), + "attention_mask": list(self.attention_mask), + } + if return_assistant_tokens_mask: + result["assistant_masks"] = list(self.assistant_masks) + return result + + +def test_train_only_assistant_tokens_auto_uses_native_generation_tags(): + tokenizer = FakeChatTokenizer(chat_template="{% generation %}{{ content }}{% endgeneration %}") + + tokenized = make_chat_tokenize_fn(tokenizer, max_length=4)({"messages": CONVERSATION}) + + assert tokenizer.return_assistant_tokens_mask_calls == [True] + assert tokenized["labels"] == [IGNORE_TOKEN_ID, 22, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID] + + +def test_train_only_assistant_tokens_auto_uses_chatml_heuristic(): + tokenizer = FakeChatTokenizer( + chatml=True, + input_ids=[100, 10, 11, 42, 101, 0], + attention_mask=[1, 1, 1, 1, 1, 0], + name_or_path="Qwen/test-tokenizer", + ) + + tokenized = make_chat_tokenize_fn(tokenizer, max_length=6)({"messages": CONVERSATION}) + + assert tokenizer.return_assistant_tokens_mask_calls == [False] + assert tokenized["labels"] == [ + IGNORE_TOKEN_ID, + IGNORE_TOKEN_ID, + IGNORE_TOKEN_ID, + 42, + IGNORE_TOKEN_ID, + IGNORE_TOKEN_ID, + ] + + +def test_train_only_assistant_tokens_auto_falls_back_for_untested_chatml(): + tokenizer = FakeChatTokenizer( + chatml=True, + input_ids=[100, 10, 11, 42, 101, 0], + attention_mask=[1, 1, 1, 1, 1, 0], + ) + + tokenized = make_chat_tokenize_fn(tokenizer, max_length=6)({"messages": CONVERSATION}) + + assert tokenizer.return_assistant_tokens_mask_calls == [False] + assert tokenized["labels"] == [100, 10, 11, 42, 101, IGNORE_TOKEN_ID] + + +def test_train_only_assistant_tokens_auto_falls_back_to_full_chat_labels(): + tokenizer = FakeChatTokenizer() + + tokenized = make_chat_tokenize_fn(tokenizer, max_length=4)({"messages": CONVERSATION}) + + assert tokenizer.return_assistant_tokens_mask_calls == [False] + assert tokenized["labels"] == [11, 22, 33, IGNORE_TOKEN_ID] + + +def test_train_only_assistant_tokens_true_requires_supported_masking(): + tokenizer = FakeChatTokenizer() + + with pytest.raises(ValueError, match="train_only_assistant_tokens: false"): + make_chat_tokenize_fn(tokenizer, max_length=4, train_only_assistant_tokens=True) + + assert tokenizer.return_assistant_tokens_mask_calls == [] + + +def test_train_only_assistant_tokens_true_allows_explicit_chatml_heuristic(): + tokenizer = FakeChatTokenizer( + chatml=True, + input_ids=[100, 10, 11, 42, 101, 0], + attention_mask=[1, 1, 1, 1, 1, 0], + ) + + tokenized = make_chat_tokenize_fn(tokenizer, max_length=6, train_only_assistant_tokens=True)( + {"messages": CONVERSATION} + ) + + assert tokenizer.return_assistant_tokens_mask_calls == [False] + assert tokenized["labels"] == [ + IGNORE_TOKEN_ID, + IGNORE_TOKEN_ID, + IGNORE_TOKEN_ID, + 42, + IGNORE_TOKEN_ID, + IGNORE_TOKEN_ID, + ] + + +def test_train_only_assistant_tokens_false_uses_full_chat_labels(): + tokenizer = FakeChatTokenizer(chat_template="{% generation %}{{ content }}{% endgeneration %}") + + tokenized = make_chat_tokenize_fn(tokenizer, max_length=4, train_only_assistant_tokens=False)( + {"messages": CONVERSATION} + ) + + assert tokenizer.return_assistant_tokens_mask_calls == [False] + assert tokenized["labels"] == [11, 22, 33, IGNORE_TOKEN_ID] + + +def test_dataset_source_config_normalizes_train_only_assistant_tokens(): + source = DatasetSourceConfig( + hf_path="dataset", + ratio=1, + split="train", + train_only_assistant_tokens="false", + ) + + assert source.train_only_assistant_tokens is False + with pytest.raises(ValueError, match="train_only_assistant_tokens"): + DatasetSourceConfig( + hf_path="dataset", + ratio=1, + split="train", + train_only_assistant_tokens="unsupported", + ) diff --git a/tests/examples/llm_qat/test_llm_qat.py b/tests/examples/llm_qat/test_llm_qat.py index 5a0e7ad4428..d4c23e096a6 100644 --- a/tests/examples/llm_qat/test_llm_qat.py +++ b/tests/examples/llm_qat/test_llm_qat.py @@ -15,99 +15,167 @@ import pytest -import torch from _test_utils.examples.run_command import run_example_command -from _test_utils.torch.misc import minimum_sm + +# Mapping from backend name to accelerate config file +BACKEND_CONFIGS = { + "fsdp2": "configs/accelerate/fsdp2.yaml", + "ddp": "configs/accelerate/ddp.yaml", + "deepspeed": "configs/accelerate/deepspeed.yaml", +} + +# Backends that need gradient checkpointing +GRADIENT_CHECKPOINTING_BACKENDS = {"ddp", "deepspeed"} # fmt: off -def _run_command(extra_cmd_args: list[str]): +def _fast_data_args(cache_dir: str) -> list[str]: + """Fast dataset overrides for all tests (small samples, no shuffle, temp cache).""" + return [ + "--dataset_config", "configs/dataset/blend_test.yaml", + "--train_samples", "64", + "--eval_samples", "16", + "--shuffle", "False", + "--dataset_cache_dir", cache_dir, + ] + + +def _run_quantize(extra_cmd_args: list[str], cache_dir: str = ""): run_example_command( [ - "./launch.sh", - "--fsdp_transformer_layer_cls_to_wrap", "LlamaDecoderLayer", - "--num_epochs", "0.3", - "--lr", "1e-5", + "python", "quantize.py", + *_fast_data_args(cache_dir), + *extra_cmd_args, + ], + "llm_qat", + ) + + +def _run_train(extra_cmd_args: list[str], backend: str = "fsdp2", cache_dir: str = ""): + config_file = BACKEND_CONFIGS[backend] + gradient_args = ( + ["--gradient_checkpointing", "True"] + if backend in GRADIENT_CHECKPOINTING_BACKENDS + else [] + ) + run_example_command( + [ + "accelerate", "launch", + "--config-file", config_file, + "train.py", + *_fast_data_args(cache_dir), + "--num_train_epochs", "0.3", + "--learning_rate", "1e-5", + "--per_device_train_batch_size", "2", + "--per_device_eval_batch_size", "2", "--save_steps", "5", - "--calib_size", "64", - "--train_size", "256", - "--eval_size", "64", + "--eval_steps", "5", + *gradient_args, *extra_cmd_args, ], "llm_qat", setup_free_port=True, ) +def test_dataset_utils_pretokenize(tiny_qwen3_path, tmp_path): + """Test dataset_utils.py standalone CLI pre-tokenization.""" + cache_dir = tmp_path / "dataset_cache" + run_example_command( + [ + "python", "dataset_utils.py", + *_fast_data_args(str(cache_dir)), + "--model_name_or_path", tiny_qwen3_path, + ], + "llm_qat", + ) + assert cache_dir.exists(), "Cache directory should be created" + assert any(cache_dir.iterdir()), "Cache directory should contain tokenized data" + + @pytest.mark.parametrize("backend", [ - pytest.param("fsdp1", marks=pytest.mark.skipif(torch.cuda.device_count() < 2, reason="need 2 GPUs!")), "fsdp2", "deepspeed", "ddp", ]) -def test_llama_qat_int4w_int8a(tiny_llama_path, tmp_path, backend): +def test_qwen3_qat_nvfp4(tiny_qwen3_path, tmp_path, backend): ptq_output_dir = tmp_path / "ptq" qat_output_dir = tmp_path / "qat" + cache_dir = str(tmp_path / "dataset_cache") - # Run PTQ - _run_command( + # Step 1: Quantize + _run_quantize( [ - "--model", tiny_llama_path, - "--do_train", "False", - "--quant_cfg", "INT4_WEIGHT_INT8_ACTIVATIONS", - "--output_dir", ptq_output_dir, - "--backend", backend, - ] + "--model_name_or_path", tiny_qwen3_path, + "--recipe", "general/ptq/nvfp4_default-kv_fp8", + "--calib_size", "64", + "--output_dir", str(ptq_output_dir), + ], + cache_dir=cache_dir, ) - # Run QAT on PTQ checkpoint - _run_command( + # Step 2: QAT + _run_train( [ - "--model", ptq_output_dir, + "--model_name_or_path", str(ptq_output_dir), "--do_train", "True", - "--output_dir", qat_output_dir, - "--backend", backend, - ] + "--output_dir", str(qat_output_dir), + ], + backend=backend, + cache_dir=cache_dir, ) -@pytest.mark.parametrize("backend", [ - pytest.param("fsdp1", marks=pytest.mark.skipif(torch.cuda.device_count() < 2, reason="need 2 GPUs!")), - "fsdp2", - "deepspeed", - "ddp", -]) -def test_llama_qat_int4w_int8a_direct_qat(tiny_llama_path, tmp_path, backend): - # Run PTQ + QAT together - _run_command( +def test_qwen3_lora_qat_nvfp4(tiny_qwen3_path, tmp_path): + ptq_output_dir = tmp_path / "ptq" + cache_dir = str(tmp_path / "dataset_cache") + + # Step 1: Quantize + _run_quantize( [ - "--model", tiny_llama_path, - "--do_train", "True", - "--quant_cfg", "INT4_WEIGHT_INT8_ACTIVATIONS", - "--output_dir", tmp_path, - "--backend", backend, - ] + "--model_name_or_path", tiny_qwen3_path, + "--recipe", "general/ptq/nvfp4_default-kv_fp8", + "--calib_size", "64", + "--output_dir", str(ptq_output_dir), + ], + cache_dir=cache_dir, ) -def test_llama_lora_qat_nvfp4(tiny_llama_path, tmp_path): - _run_command( + # Step 2: LoRA QAT + _run_train( [ - "--model", tiny_llama_path, + "--model_name_or_path", str(ptq_output_dir), "--do_train", "True", - "--quant_cfg", "NVFP4_DEFAULT_CFG", "--lora", "True", - "--output_dir", tmp_path / "lora_qat", - "--backend", "fsdp2", - ] + "--output_dir", str(tmp_path / "lora_qat"), + ], + backend="fsdp2", + cache_dir=cache_dir, ) -@minimum_sm(90) -def test_llama_qlora_nvfp4(tiny_llama_path, tmp_path): - _run_command( + +def test_qwen3_qlora_nvfp4(tiny_qwen3_path, tmp_path): + ptq_output_dir = tmp_path / "ptq" + cache_dir = str(tmp_path / "dataset_cache") + + # Step 1: Quantize with compression for QLoRA + _run_quantize( [ - "--model", tiny_llama_path, + "--model_name_or_path", tiny_qwen3_path, + "--recipe", "general/ptq/nvfp4_default-kv_fp8", + "--calib_size", "64", + "--compress", "True", + "--output_dir", str(ptq_output_dir), + ], + cache_dir=cache_dir, + ) + + # Step 2: QLoRA training + _run_train( + [ + "--model_name_or_path", str(ptq_output_dir), "--do_train", "True", - "--quant_cfg", "NVFP4_DEFAULT_CFG", "--lora", "True", - "--compress", "True", - "--output_dir", tmp_path / "qlora", - ] + "--output_dir", str(tmp_path / "qlora"), + ], + backend="ddp", + cache_dir=cache_dir, ) diff --git a/tests/unit/torch/opt/plugins/test_modelopt_arg_parser.py b/tests/unit/torch/opt/plugins/test_modelopt_arg_parser.py new file mode 100644 index 00000000000..78a89a27ff9 --- /dev/null +++ b/tests/unit/torch/opt/plugins/test_modelopt_arg_parser.py @@ -0,0 +1,122 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for ModelOptArgParser.""" + +from dataclasses import field +from pathlib import Path + +import pytest + +pytest.importorskip("transformers") + +from modelopt.torch.opt.plugins.transformers import ModelOptArgParser, ModelOptHFArguments + + +class _ModelArgs(ModelOptHFArguments): + model_name: str = field(default="test-model", metadata={"help": "The model name."}) + hidden_size: int = field(default=128, metadata={"help": "Hidden size."}) + + +class _TrainArgs(ModelOptHFArguments): + learning_rate: float = field(default=1e-4, metadata={"help": "Learning rate."}) + epochs: int = field(default=3, metadata={"help": "Number of epochs."}) + + +class TestModelOptArgParser: + """Tests for ModelOptArgParser --config and --generate_docs features.""" + + def test_cli_args_only(self): + parser = ModelOptArgParser((_ModelArgs, _TrainArgs)) + model_args, train_args = parser.parse_args_into_dataclasses( + args=["--model_name", "my-model", "--learning_rate", "0.01"] + ) + assert model_args.model_name == "my-model" + assert train_args.learning_rate == 0.01 + assert train_args.epochs == 3 # default + + def test_yaml_config(self, tmp_path): + config_file = tmp_path / "config.yaml" + config_file.write_text("model_name: yaml-model\nepochs: 10\n") + + parser = ModelOptArgParser((_ModelArgs, _TrainArgs)) + model_args, train_args = parser.parse_args_into_dataclasses( + args=["--config", str(config_file)] + ) + assert model_args.model_name == "yaml-model" + assert train_args.epochs == 10 + + def test_cli_overrides_yaml(self, tmp_path): + config_file = tmp_path / "config.yaml" + config_file.write_text("model_name: yaml-model\nlearning_rate: 0.001\n") + + parser = ModelOptArgParser((_ModelArgs, _TrainArgs)) + model_args, train_args = parser.parse_args_into_dataclasses( + args=["--config", str(config_file), "--learning_rate", "0.01"] + ) + assert model_args.model_name == "yaml-model" # from yaml + assert train_args.learning_rate == 0.01 # CLI override + + def test_empty_yaml_config(self, tmp_path): + config_file = tmp_path / "empty.yaml" + config_file.write_text("") + + parser = ModelOptArgParser((_ModelArgs, _TrainArgs)) + model_args, train_args = parser.parse_args_into_dataclasses( + args=["--config", str(config_file)] + ) + assert model_args.model_name == "test-model" # defaults + assert train_args.epochs == 3 + + def test_generate_docs(self, tmp_path): + output_path = tmp_path / "ARGUMENTS.md" + parser = ModelOptArgParser((_ModelArgs, _TrainArgs)) + + with pytest.raises(SystemExit) as exc_info: + parser.parse_args_into_dataclasses(args=["--generate_docs", str(output_path)]) + assert exc_info.value.code == 0 + + content = output_path.read_text() + assert "## _ModelArgs" in content + assert "## _TrainArgs" in content + assert "--model_name" in content + assert "--learning_rate" in content + assert "--epochs" in content + + def test_generate_docs_default_path(self, tmp_path, monkeypatch): + monkeypatch.chdir(tmp_path) + parser = ModelOptArgParser((_ModelArgs, _TrainArgs)) + + with pytest.raises(SystemExit) as exc_info: + parser.parse_args_into_dataclasses(args=["--generate_docs"]) + assert exc_info.value.code == 0 + + content = Path("ARGUMENTS.md").read_text() + assert "# Argument Reference" in content + + def test_docs_table_format(self, tmp_path): + output_path = tmp_path / "ARGUMENTS.md" + parser = ModelOptArgParser((_ModelArgs, _TrainArgs)) + + with pytest.raises(SystemExit): + parser.parse_args_into_dataclasses(args=["--generate_docs", str(output_path)]) + + content = output_path.read_text() + # Check table headers + assert "| Argument | Type | Default | Description |" in content + # Check a specific row + assert "`--model_name`" in content + assert '`"test-model"`' in content + assert "The model name." in content