This folder demonstrates how the Model Optimizer applies post-training quantization (PTQ) to an LLM and deploys the quantized LLM with TensorRT-LLM.
This document introduces:
- The scripts to quantize, convert and evaluate LLMs,
- The Python code and APIs to quantize and deploy the models.
If you are interested in quantization-aware training (QAT) for LLMs, please refer to the QAT README.
There are many quantization schemes supported in the example scripts:
- The FP8 format is available on the Hopper and Ada GPUs with CUDA compute capability greater than or equal to 8.9.
- The INT8 SmoothQuant, developed by MIT HAN Lab and NVIDIA, is designed to reduce both the GPU memory footprint and inference latency of LLM inference.
- The INT4 AWQ is an INT4 weight-only quantization and calibration method. INT4 AWQ is particularly effective for low batch inference, where inference latency is dominated by weight loading time rather than the computation time itself. For low batch inference, INT4 AWQ could give lower latency than FP8/INT8 and lower accuracy degradation than INT8.
- The W4A8 AWQ is an extension of INT4 AWQ that also uses FP8 for activations, providing additional speedup.
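For reference, when using the Python quantization API directly (see the PTQ API section below), these formats roughly map to ModelOpt quantization configs as sketched here. The config names reflect recent nvidia-modelopt releases and should be verified against your installed version:

```python
import modelopt.torch.quantization as mtq

# Rough mapping from the script-level --quant options to ModelOpt quantization configs.
# Config names are assumptions based on recent ModelOpt releases; check your version.
QUANT_CFG_CHOICES = {
    "fp8": mtq.FP8_DEFAULT_CFG,
    "int8_sq": mtq.INT8_SMOOTHQUANT_CFG,
    "int4_awq": mtq.INT4_AWQ_CFG,
    "w4a8_awq": mtq.W4A8_AWQ_BETA_CFG,  # experimental, sm90 GPUs only
}
```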
The following scripts provide an all-in-one and step-by-step model quantization example for Llama-3, NeMo Nemotron, and Megatron-LM models. The quantization format and the number of GPUs will be supplied as inputs to these scripts. By default, we build the engine for the fp8 format and 1 GPU.
cd <this example folder>
For LLM models like Llama-3:
# Install model specific pip dependencies if needed
export HF_PATH=<the downloaded LLaMA checkpoint from the Hugging Face hub, or simply the model card>
scripts/huggingface_example.sh --model $HF_PATH --quant [fp8|int8_sq|int4_awq|w4a8_awq] --tp [1|2|4|8]
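For example, to quantize a Llama-3 8B checkpoint to FP8 on a single GPU (the model card below is illustrative; substitute your own checkpoint path or card):

```bash
# Illustrative invocation; adjust the model card, quantization format, and TP size as needed
export HF_PATH=meta-llama/Meta-Llama-3-8B
scripts/huggingface_example.sh --model $HF_PATH --quant fp8 --tp 1
```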
If Hugging Face model calibration fails on a multi-GPU system due to mismatched tensor placement, please try restricting the visible devices by setting CUDA_VISIBLE_DEVICES to a smaller number of GPUs.
FP8 calibration of a large model with limited GPU memory is not recommended, but it is possible with the accelerate package. Please tune the device_map setting in example_utils.py if needed for model loading (a sketch of the relevant call is shown below); note that the calibration process can be slow.
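The sketch below shows the kind of accelerate-backed loading that example_utils.py performs; the model card, variable names, and memory limits here are illustrative, not the actual implementation:

```python
import torch
from transformers import AutoModelForCausalLM

# Illustrative only: spread a large model across the available GPUs (and CPU if needed)
# for calibration. Tune max_memory / device_map for your hardware.
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B",              # hypothetical model card
    torch_dtype=torch.bfloat16,
    device_map="auto",                          # let accelerate place the layers
    max_memory={0: "70GiB", "cpu": "128GiB"},   # example limits, adjust per system
)
```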
Hugging Face models trained with modelopt.torch.speculative can be used as regular Hugging Face models in PTQ.
For NeMo models like Nemotron:
NeMo PTQ requires the NeMo package to be installed. It is recommended to start directly from a NeMo container such as nvcr.io/nvidia/nemo:24.07 or the latest nvcr.io/nvidia/nemo:dev.
# Inside the NeMo container:
# Download the Nemotron model from Hugging Face.
export GPT_MODEL_FILE=Nemotron-3-8B-Base-4k.nemo
# Reinstall latest modelopt and build the extensions if not already done.
pip install -U "nvidia-modelopt[torch]" --extra-index-url https://pypi.nvidia.com
python -c "import modelopt.torch.quantization.extensions as ext; ext.precompile()"
scripts/nemo_example.sh --type gpt --model $GPT_MODEL_FILE --quant [fp8|int8_sq|int4_awq] --tp [1|2|4|8]
If the TensorRT-LLM version in the NeMo container is older than the supported version, please continue building the TensorRT-LLM engine with the docker.io/library/modelopt_examples:latest container built in the ModelOpt docker build step. Additionally, you will also need to run pip install megatron-core "nemo-toolkit[all]" --extra-index-url https://pypi.nvidia.com to install the required NeMo dependencies.
Megatron-LM PTQ and TensorRT-LLM deployment examples are maintained in the Megatron-LM GitHub repository; please refer to the examples there.
If a GPU out-of-memory error is reported when running the scripts, please try editing the scripts to reduce the maximum batch size of the TensorRT-LLM engine and save GPU memory.
The example scripts above also accept a --tasks flag that customizes which tasks the script actually runs. The allowed tasks are build, mmlu, humaneval, benchmark, and lm_eval, as specified in the script parser, and a combination of tasks can be passed as a comma-separated list (see the example below). Some tasks, such as mmlu and humaneval, can take a long time to run. To run lm_eval tasks, please also specify the --lm_eval_tasks flag with a comma-separated list of lm_eval tasks.
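For example, the following invocation builds the engine, runs MMLU, and then runs two lm_eval tasks (the lm_eval task names are illustrative lm-evaluation-harness tasks; choose any tasks supported by the script):

```bash
# Illustrative task selection; adjust the task lists to your needs
scripts/huggingface_example.sh --model $HF_PATH --quant fp8 \
    --tasks build,mmlu,lm_eval --lm_eval_tasks gsm8k,hellaswag
```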
Please refer to the Technical Details section below for the stages executed inside the scripts and the outputs produced by each stage.
Model | fp8 | int8_sq | int4_awq | w4a8_awq¹ |
---|---|---|---|---|
GPT2 | Yes | Yes | No | No |
GPTJ | Yes | Yes | Yes | Yes |
LLAMA 2 | Yes | Yes | Yes | Yes |
LLAMA 3, 3.1 | Yes | No | Yes | No |
LLAMA 2 (Nemo) | Yes | Yes | Yes | Yes |
CodeLlama | Yes | Yes | Yes | No |
Mistral | Yes | Yes | Yes | No |
Mixtral 8x7B, 8x22B | Yes³ | No | Yes² | No |
Snowflake Arctic² | Yes | No | Yes | No |
Falcon 40B, 180B | Yes | Yes | Yes | Yes |
Falcon 7B | Yes | Yes | No | No |
MPT 7B, 30B | Yes | Yes | Yes | Yes |
Baichuan 1, 2 | Yes | Yes | Yes | Yes |
ChatGLM2, 3 6B | No | No | Yes | No |
Bloom | Yes | Yes | Yes | Yes |
Phi-1,2,3 | Yes | Yes | Yes | Yes⁴ |
Phi-3.5 MOE | Yes | No | No | No |
Nemotron 8B | Yes | No | Yes | No |
Gemma 2B, 7B | Yes | No | Yes | Yes |
Gemma 2 9B, 27B | Yes | No | Yes | No |
RecurrentGemma 2B | Yes | Yes | Yes | No |
StarCoder 2 | Yes | Yes | Yes | No |
QWen⁵ | Yes | Yes | Yes | Yes |
DBRX | Yes | No | No | No |
InternLM2 | Yes | No | Yes | Yes⁴ |
Exaone | Yes | Yes | Yes | Yes |
Minitron | Yes | Yes | Yes | Yes² |
¹ w4a8_awq is an experimental quantization scheme that may result in a higher accuracy penalty. It is only available on sm90 GPUs.
² For some models, only exporting quantized checkpoints is supported.
³ Mixtral FP8 is only available on sm90 GPUs.
⁴ W4A8_AWQ is only available for some models, not all.
⁵ For some models, KV cache quantization may result in a higher accuracy penalty.
The accuracy loss after PTQ varies with the model and the quantization method; it is usually more significant when the base model is small. If the accuracy after PTQ does not meet your requirements, please try either modifying hf_ptq.py to disable KV cache quantization, or using QAT instead.
This feature is experimental.
Besides TensorRT-LLM, the Model Optimizer also supports deploying FP8-quantized Hugging Face LLMs using vLLM. Model Optimizer supports exporting a unified checkpoint¹ that is compatible with vLLM deployment². The unified checkpoint format aligns with the original model built from the HF transformers library. A unified checkpoint can be exported using the following command:
# Quantize and export
scripts/huggingface_example.sh --model <huggingface_model_card> --quant fp8 --export_fmt hf
Then launch an inference instance using vLLM in Python, for example:
from vllm import LLM
llm_fp8 = LLM(model="<the exported model path>", quantization="modelopt")
print(llm_fp8.generate(["What's the age of the earth? "]))
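Generation settings can be controlled with vLLM's SamplingParams; a minimal sketch follows (the parameter values are illustrative):

```python
from vllm import LLM, SamplingParams

llm_fp8 = LLM(model="<the exported model path>", quantization="modelopt")
# Illustrative sampling settings; tune these for your use case.
params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=64)
outputs = llm_fp8.generate(["What's the age of the earth? "], params)
for out in outputs:
    print(out.outputs[0].text)
```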
¹ Unified checkpoint export currently does not support sparsity, speculative decoding, KV cache quantization, or AutoQuantize.
² The exported checkpoint can be deployed on vLLM with the changes from this unmerged PR; users can deploy to vLLM directly once the PR has been merged.
Model | FP8 |
---|---|
LLAMA 2 | Yes |
LLAMA 3, 3.1 | Yes |
QWen2 | Yes |
Mixtral 8x7B | Yes |
CodeLlama | Yes |
AutoQuantize is a PTQ algorithm from ModelOpt which quantizes a model by searching for the best quantization format per layer while meeting the performance constraint specified by the user. This way, AutoQuantize enables trading off model accuracy for performance.
Currently, AutoQuantize supports only effective_bits as the performance constraint (for both weight-only quantization and weight & activation quantization). See the AutoQuantize documentation for more details.
AutoQuantize can be performed for Hugging Face LLM models like Llama-3 as shown below:
export HF_PATH=<the downloaded LLaMA checkpoint from the Hugging Face hub, or simply the model card>
# --effective_bits specifies the constraint for `AutoQuantize`
# --quant specifies the format to be searched for `AutoQuantize`
scripts/huggingface_example.sh --type llama --model $HF_PATH --quant w4a8_awq,fp8 --effective_bits 4.8 --tp [1|2|4|8] --calib_batch_size 4
The above example performs AutoQuantize, where the layers that are less sensitive to quantization are quantized with w4a8_awq (specified by --quant w4a8_awq) and the more sensitive layers are kept unquantized, such that the effective bits is 4.8 (specified by --effective_bits 4.8).
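As a rough illustration (this is a simplified reading of the constraint, not ModelOpt's exact definition of effective_bits): if the unquantized layers keep 16-bit weights and the quantized layers use 4-bit w4a8_awq weights, then a fraction x of weights in the 4-bit format satisfies 4x + 16(1 - x) = 4.8, which gives x ≈ 0.93, i.e. roughly 93% of the weights end up in the 4-bit format.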
The usage is similar for NeMo models to perform AutoQuantize. Please refer to the earlier section on NeMo models for the full setup instructions.
# --effective_bits specifies the constraint for `AutoQuantize`
# --quant specifies the formats to be searched for `AutoQuantize`. Multiple formats can be searched over by passing them as comma separated values
scripts/nemo_example.sh --type gpt --model $GPT_MODEL_FILE --quant fp8,int4_awq --effective_bits 6.4 --tp [1|2|4|8]
hf_ptq.py and nemo_ptq.py use the Model Optimizer to calibrate the PyTorch models and generate a TensorRT-LLM checkpoint, saved as a JSON file (for the model structure) and safetensors files (for the model weights) that TensorRT-LLM can parse.
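The exported checkpoint directory typically looks like the following; the file names follow the usual TensorRT-LLM checkpoint convention, and the number of rank files depends on the chosen tensor/pipeline parallelism:

```bash
# Typical layout of the exported TensorRT-LLM checkpoint (illustrative)
ls <export_dir>
# config.json        <- model structure and quantization metadata
# rank0.safetensors  <- weights for tensor/pipeline parallel rank 0
# rank1.safetensors  <- additional ranks are present when inference TP x PP > 1
```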
Quantization requires running the model in its original precision (FP16/BF16). As an example, below is the recommended minimum number of GPUs based on our testing:
Model | 24 GB GPUs (4090, A5000, L40) | 48 GB GPUs (A6000, L40s) | 80 GB GPUs (A100, H100) |
---|---|---|---|
Llama2 7B | 1 | 1 | 1 |
Llama2 13B | 2 | 1 | 1 |
Llama2 70B | 8 | 4 | 2 |
Falcon 180B | Not supported | Not supported | 8 |
Before quantizing the model, hf_ptq.py offers users the option to sparsify the weights of their models using a 2:4 pattern. This reduces the memory footprint during inference and can lead to performance improvements. The following is an example command to enable weight sparsity:
scripts/huggingface_example.sh --model $HF_PATH --quant [fp16|fp8|int8_sq] --tp [1|2|4|8] --sparsity sparsegpt
The accuracy loss due to post-training sparsification depends on the model and the downstream tasks. To best preserve accuracy, sparsegpt is recommended. However, users should be aware that there could still be a noticeable drop in accuracy. Read more about Post-training Sparsification and Sparsity Aware Training (SAT) in the Sparsity README.
The script modelopt_to_tensorrt_llm.py constructs the TensorRT-LLM network and builds the TensorRT-LLM engine for deployment, using the quantized model config files output by the previous step. The generated engine(s) will be saved as .engine file(s), ready for deployment.
A list of accuracy validation benchmarks is provided in the llm_eval directory. Currently, MMLU, HumanEval, and MT-Bench are supported in this example by specifying the --tasks flag when running the scripts mentioned above. For MT-Bench, the task only runs the answer generation stage; please follow FastChat to obtain the evaluation judge scores.
The benchmark_suite.py script is used as a fast performance benchmark. For details, please refer to the TensorRT-LLM documentation.
This example also covers the lm_evaluation_harness, MMLU, and HumanEval accuracy benchmarks, whose details can be found here. The supported lm_eval evaluation tasks are listed here.
PTQ can be achieved with simple calibration on a small set of training or evaluation data (typically 128-512 samples) after converting a regular PyTorch model to a quantized model. The accuracy of PTQ is typically robust across different choices of calibration data, so we use cnn_dailymail by default. Users can try other datasets by modifying get_calib_dataloader in example_utils.py, for example as sketched below.
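The following is a minimal sketch of how such a calibration set could be built from a Hugging Face dataset; the helper name and the batching details are illustrative, not the actual get_calib_dataloader implementation:

```python
from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import AutoTokenizer


def build_calib_dataloader(model_card, num_samples=512, batch_size=4, max_length=512):
    """Illustrative calibration dataloader built from cnn_dailymail articles."""
    tokenizer = AutoTokenizer.from_pretrained(model_card)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token  # e.g. Llama tokenizers define no pad token
    dataset = load_dataset("cnn_dailymail", "3.0.0", split="train")
    texts = dataset.select(range(num_samples))["article"]
    encodings = tokenizer(
        texts,
        truncation=True,
        max_length=max_length,
        padding="max_length",
        return_tensors="pt",
    )
    # Each batch is a tensor of token ids; adapt the collation to what your forward loop expects.
    return DataLoader(encodings["input_ids"], batch_size=batch_size)
```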
import modelopt.torch.quantization as mtq
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("...")

# Select the quantization config, for example, INT8 SmoothQuant
config = mtq.INT8_SMOOTHQUANT_CFG

# Prepare the calibration set and define a forward loop over it.
# calib_set is an iterable of calibration batches prepared by the user.
def forward_loop(model):
    for data in calib_set:
        model(data)

# PTQ with in-place replacement of modules by their quantized counterparts
model = mtq.quantize(model, config, forward_loop)
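After quantization, it can be useful to inspect which modules were quantized. Recent ModelOpt versions provide a summary helper for this; the name below is assumed from ModelOpt examples, so check your installed version:

```python
# Print per-module quantizer information to verify that quantization took effect.
# Assumes mtq.print_quant_summary is available in your ModelOpt version.
mtq.print_quant_summary(model)
```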
After the model is quantized, the TensorRT-LLM checkpoint can be stored. The user can specify the inference-time tensor parallel (TP) and pipeline parallel (PP) sizes, and the export API will organize the weights to fit the target GPUs. The export API is:
import torch
from modelopt.torch.export import export_tensorrt_llm_checkpoint

with torch.inference_mode():
    export_tensorrt_llm_checkpoint(
        model,                        # The quantized model.
        decoder_type,                 # The type of the model, e.g. gpt, gptj, or llama.
        dtype,                        # The exported weights data type.
        export_dir,                   # The directory where the exported files will be stored.
        inference_tensor_parallel,    # The number of GPUs used for inference-time tensor parallelism.
        inference_pipeline_parallel,  # The number of GPUs used for inference-time pipeline parallelism.
        use_nfs_workspace,            # For a multi-node export, specify a shared directory such as NFS for cross-node communication.
    )
After the TensorRT-LLM checkpoint export, you can use the trtllm-build command to build the engines from the exported checkpoints. Please check the TensorRT-LLM Build API documentation for reference.
In this example, we use the modelopt_to_tensorrt_llm.py script as an easy-to-use wrapper over the TensorRT-LLM build API to generate the engines.
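If you prefer to call trtllm-build directly instead of the wrapper script, a typical invocation looks like the following; the flag values are illustrative and the available options depend on your TensorRT-LLM version:

```bash
# Illustrative trtllm-build invocation over the exported checkpoint directory
trtllm-build --checkpoint_dir <export_dir> \
             --output_dir <engine_dir> \
             --max_batch_size 8
```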