PyTorch Quantization
[!note]
- Official definition: performing computations and storing tensors at lower bitwidths than floating point precision.
- INT8 quantization is supported; it cuts model size and memory footprint by roughly 4x and speeds up inference by 2-4x.
- In plain terms: lower the precision of weights and activations (FP32 → INT8), thereby reducing model size and memory requirements.
1. Prerequisites
1.1 Operator Fusion
Merge the computation of several consecutive layers into a single composite operator, reducing the number of memory accesses.
e.g. fuse Conv → BN → ReLU into a single ConvBnReLU operator.
| Pipeline | Memory accesses | Arithmetic intensity |
|---|---|---|
| Unfused (3 operators) | 6 | Low |
| Fused (1 operator) | 2 | High |
NVIDIA GPU:
// Unfused: three kernel launches, each round-tripping through global memory
conv_kernel<<<...>>>(input, weight, temp1);
bias_kernel<<<...>>>(temp1, bias, temp2);
relu_kernel<<<...>>>(temp2, output);
// Fused: a single kernel performs all three operations (pseudocode)
fused_kernel<<<...>>>(input, weight, bias, output) {
    float val = conv2d(input, weight);
    val += bias;
    output = max(val, 0.0f);
}
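The same fusion can be requested at the PyTorch level. A minimal sketch (the layer sizes and the Sequential child names "0"/"1"/"2" are arbitrary here): torch.ao.quantization.fuse_modules folds a Conv → BN → ReLU chain into one fused module:

import torch
import torch.nn as nn

# a tiny Conv -> BN -> ReLU block
m = nn.Sequential(
    nn.Conv2d(3, 8, 3),
    nn.BatchNorm2d(8),
    nn.ReLU(),
)
m.eval()  # eager-mode fusion requires eval mode so BN can be folded into the conv

# fuse the three children ("0", "1", "2") into a single module
fused = torch.ao.quantization.fuse_modules(m, [["0", "1", "2"]])
print(type(fused[0]))  # a fused Conv+ReLU module, with BN folded into the conv weights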
2. Quantization Basics
2.1 Symmetric vs. Asymmetric Quantization
⚙️ Differences
- Symmetric quantization:
  $X_{int} = \mathrm{round}\left(\frac{X_{float}}{scale}\right), \quad scale = \frac{\max(|X|)}{2^{n-1}-1}$
- Asymmetric (affine) quantization (a numeric sketch of both schemes follows below):
  $X_{int} = \mathrm{round}\left(\frac{X_{float}}{scale}\right) + zero\_point, \quad scale = \frac{\max(X)-\min(X)}{2^{n}-1}$
  $zero\_point = \mathrm{round}\left(\frac{-\min(X)}{scale}\right)$
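A small numeric sketch of both formulas for n = 8 bits (the values in x are made up for illustration):

import torch

x = torch.tensor([-0.5, 0.0, 1.2, 3.0])

# symmetric int8: scale from max(|x|), zero point fixed at 0
scale_sym = x.abs().max() / 127                      # 3.0 / 127 ≈ 0.0236
q_sym = torch.round(x / scale_sym).clamp(-127, 127).to(torch.int8)

# asymmetric uint8: scale from the full range, plus a zero_point
scale_aff = (x.max() - x.min()) / 255                # 3.5 / 255 ≈ 0.0137
zero_point = torch.round(-x.min() / scale_aff)       # ≈ 36
q_aff = (torch.round(x / scale_aff) + zero_point).clamp(0, 255).to(torch.uint8)

# dequantize to inspect the rounding error of each scheme
x_sym = q_sym.float() * scale_sym
x_aff = (q_aff.float() - zero_point) * scale_aff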
| Property | Symmetric quantization | Asymmetric (affine) quantization |
|---|---|---|
| Zero point | Fixed at 0 | Computed from the data (zero_point) |
| Value range | [-127, 127] (int8) | [0, 255] (uint8) |
| Compute overhead | Lower (no zero_point handling) | Higher |
| Accuracy | Sensitive to skewed distributions | More robust; handles skewed distributions better |
| Typical use | Weights (roughly balanced around 0) | Activations |
| Hardware support | Widely supported (GPU/TPU) | Needs extra zero_point handling |
🤖 From an engineering standpoint: why does PTQ commonly use asymmetric quantization while QAT uses symmetric?
| Mode | Recommended default | Rationale |
|---|---|---|
| PTQ | Weights: symmetric; Activations: asymmetric | Activations are quantized statically, without training, so asymmetric quantization better fits their non-negative (e.g. post-ReLU) distributions (see the QConfig sketch below) |
| QAT | Weights: symmetric; Activations: symmetric (by design) | Activations are trainable, so training can push their distribution toward "symmetric", making the accuracy loss more controllable |
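For example, the PTQ default in the table can be spelled out as a QConfig. A sketch only: MinMaxObserver is just one observer choice, histogram or moving-average observers are equally valid here.

import torch
from torch.ao.quantization import QConfig, MinMaxObserver

# asymmetric (affine) uint8 for activations, symmetric int8 for weights
ptq_qconfig = QConfig(
    activation=MinMaxObserver.with_args(
        dtype=torch.quint8, qscheme=torch.per_tensor_affine
    ),
    weight=MinMaxObserver.with_args(
        dtype=torch.qint8, qscheme=torch.per_tensor_symmetric
    ),
)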
2.2 PTQ & QAT
[!note]
PTQ quantizes the parameters of an already-trained model directly, so it is suited to rapid deployment. QAT inserts fake-quantization nodes and simulates quantization error during training to reach higher accuracy, so it requires retraining.
⚙️ Differences
| Property | PTQ (post-training quantization) | QAT (quantization-aware training) |
|---|---|---|
| Training stage | FP32 training only | Training with fake-quantization nodes inserted |
| Backpropagation | ❌ Not involved | ✅ Supported via STE |
| Accuracy loss | Larger (especially for small models) | Usually smaller |
| Compute cost | Low (calibration only) | High (full training run) |
| Typical use | Rapid deployment | Accuracy-critical scenarios |
[!tip]
QAT fake-quantization nodes
- Purpose: simulate quantization error during training. Weights and activations remain FP32, but in each layer's forward pass the values are "quantized and then dequantized", mimicking the quantization process.
- Because quantization involves a round function, which is non-differentiable, a FakeQuant module that approximates the gradient with a Straight-Through Estimator (STE) is required (a hand-rolled sketch follows below).
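PyTorch ships this logic as torch.ao.quantization.FakeQuantize; purely for intuition, a hand-rolled sketch of a fake-quant node with STE could look like this (the function name and the fixed int8 range are assumptions of the sketch):

import torch

def fake_quantize(x, scale, zero_point, qmin=-128, qmax=127):
    # forward: quantize then dequantize, so the tensor stays FP32
    # but carries the rounding/clamping error of int8
    q = torch.clamp(torch.round(x / scale) + zero_point, qmin, qmax)
    x_dq = (q - zero_point) * scale
    # STE: the detach() makes the backward pass treat round() as identity,
    # so gradients flow through x unchanged
    return x + (x_dq - x).detach()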
3. Three Ways to Implement Quantization in PyTorch
| Feature | Eager Mode QAT | FX Graph QAT | Export QAT |
|---|---|---|---|
| Mechanism | Eager (dynamic graph) mode | Symbolic graph rewriting | Compiler-level optimization |
| Control-flow support | ✅ | ❌ | ✅ |
| Operator fusion | ❌ (manual fusion only) | ✅ | ✅🌟 |
| Typical API | prepare_qat | prepare_fx | export |
| Granularity | Modules only | Modules & functions | Modules & functions |
[!note]
Whether you use PTQ or QAT, every workflow goes through a prepare step and a convert step; with the FX API these are prepare_fx and convert_fx:
model_prepared = quantize_fx.prepare_fx(model, qconfig_mapping, example_inputs)
model_quantized = quantize_fx.convert_fx(model_prepared)
🎯 Core function: at every quantization location specified by qconfig_mapping (e.g. Conv2d, Linear), insert the corresponding observer or fake_quant node.
📦 Two kinds of modules can be inserted (a standalone Observer demo follows below):

| Type | Used by | Description |
|---|---|---|
| Observer | PTQ (prepare) | Records min/max during calibration, used to compute scale and zero_point |
| FakeQuantize | QAT (prepare_qat) | Simulates quantization error while preserving gradient flow, so training can continue |
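To see what an Observer does, it can be run standalone (a small sketch; the tensor shapes and number of batches are arbitrary):

import torch
from torch.ao.quantization import MinMaxObserver

obs = MinMaxObserver(dtype=torch.quint8, qscheme=torch.per_tensor_affine)
for _ in range(4):
    obs(torch.randn(8, 16))   # the forward pass only records running min/max
scale, zero_point = obs.calculate_qparams()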
3.1 Eager Mode Quantization
import torch
# define a floating point model where some layers could benefit from QAT
class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # QuantStub converts tensors from floating point to quantized
        self.quant = torch.ao.quantization.QuantStub()
        self.conv = torch.nn.Conv2d(1, 1, 1)
        self.bn = torch.nn.BatchNorm2d(1)
        self.relu = torch.nn.ReLU()
        # DeQuantStub converts tensors from quantized to floating point
        self.dequant = torch.ao.quantization.DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        x = self.dequant(x)
        return x
# create a model instance
model_fp32 = M()
# model must be set to eval for fusion to work
model_fp32.eval()
# attach a global qconfig, which contains information about what kind
# of observers to attach. Use 'x86' for server inference and 'qnnpack'
# for mobile inference. Other quantization configurations such as selecting
# symmetric or asymmetric quantization and MinMax or L2Norm calibration techniques
# can be specified here.
# Note: the old 'fbgemm' is still available but 'x86' is the recommended default
# for server inference.
# model_fp32.qconfig = torch.ao.quantization.get_default_qconfig('fbgemm')
model_fp32.qconfig = torch.ao.quantization.get_default_qat_qconfig('x86')
# fuse the activations to preceding layers, where applicable
# this needs to be done manually depending on the model architecture
model_fp32_fused = torch.ao.quantization.fuse_modules(
    model_fp32, [['conv', 'bn', 'relu']])
# Prepare the model for QAT. This inserts observers and fake_quants in the model
# that will observe weight and activation tensors during training/calibration.
# The model needs to be set to train mode for the QAT logic to work.
model_fp32_prepared = torch.ao.quantization.prepare_qat(model_fp32_fused.train())
# run the training loop (not shown)
training_loop(model_fp32_prepared)
# Convert the observed model to a quantized model. This does several things:
# quantizes the weights, computes and stores the scale and bias value to be
# used with each activation tensor, fuses modules where appropriate,
# and replaces key operators with quantized implementations.
model_fp32_prepared.eval()
model_int8 = torch.ao.quantization.convert(model_fp32_prepared)
# run the model, relevant calculations will happen in int8
input_fp32 = torch.randn(4, 1, 4, 4)  # example input (shape chosen for this toy model)
res = model_int8(input_fp32)
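As a rough sanity check of the "about 4x smaller" claim from the note at the top (a sketch; the file names are arbitrary, and on this toy one-channel model the files are dominated by metadata, so the ratio only approaches 4x for realistically sized models):

import os

torch.save(model_fp32.state_dict(), "model_fp32.pt")
torch.save(model_int8.state_dict(), "model_int8.pt")
print(os.path.getsize("model_fp32.pt") / os.path.getsize("model_int8.pt"))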
3.2 FX Graph Mode Quantization (maintenance mode)
import torch
from torch.ao.quantization import (
get_default_qconfig_mapping,
get_default_qat_qconfig_mapping,
QConfigMapping,
)
import torch.ao.quantization.quantize_fx as quantize_fx
import copy
model_fp = UserModel()  # UserModel stands in for your own FP32 model
#
# post training dynamic/weight_only quantization
#
# deepcopy if we want to keep model_fp unchanged after quantization,
# since the quantization APIs modify the input model
model_to_quantize = copy.deepcopy(model_fp)
model_to_quantize.eval()
qconfig_mapping = QConfigMapping().set_global(torch.ao.quantization.default_dynamic_qconfig)
# a tuple of one or more example inputs is needed to trace the model
example_inputs = (input_fp32,)  # input_fp32: a representative example input for UserModel
# prepare
model_prepared = quantize_fx.prepare_fx(model_to_quantize, qconfig_mapping, example_inputs)
# no calibration needed when we only have dynamic/weight_only quantization
# quantize
model_quantized = quantize_fx.convert_fx(model_prepared)
#
# post training static quantization
#
model_to_quantize = copy.deepcopy(model_fp)
qconfig_mapping = get_default_qconfig_mapping("qnnpack")
model_to_quantize.eval()
# prepare
model_prepared = quantize_fx.prepare_fx(model_to_quantize, qconfig_mapping, example_inputs)
# calibrate: run representative data through the prepared model
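# minimal calibration sketch (calib_loader is an assumed DataLoader of
# representative inputs); the observers record min/max for scale/zero_point
with torch.no_grad():
    for batch in calib_loader:
        model_prepared(batch)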
# quantize
model_quantized = quantize_fx.convert_fx(model_prepared)
#
# quantization aware training for static quantization
#
model_to_quantize = copy.deepcopy(model_fp)
qconfig_mapping = get_default_qat_qconfig_mapping("qnnpack")
model_to_quantize.train()
# prepare
model_prepared = quantize_fx.prepare_qat_fx(model_to_quantize, qconfig_mapping, example_inputs)
# training loop (not shown)
# quantize
model_quantized = quantize_fx.convert_fx(model_prepared)
#
# fusion
#
model_to_quantize = copy.deepcopy(model_fp)
model_fused = quantize_fx.fuse_fx(model_to_quantize)
3.3 PyTorch 2 Export Quantization
import torch
from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e
from torch.export import export_for_training
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)
class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(5, 10)

    def forward(self, x):
        return self.linear(x)
# initialize a floating point model
float_model = M().eval()
# define calibration function
def calibrate(model, data_loader):
    model.eval()
    with torch.no_grad():
        for image, target in data_loader:
            model(image)
# Step 1. program capture
# NOTE: this API will be updated to the torch.export API in the future, but the
# captured result should mostly stay the same
example_inputs = (torch.randn(1, 5),)  # example input matching M's Linear(5, 10)
m = export_for_training(float_model, example_inputs).module()
# we get a model with aten ops
# Step 2. quantization
# backend developer will write their own Quantizer and expose methods to allow
# users to express how they
# want the model to be quantized
quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config())
# or prepare_qat_pt2e for Quantization Aware Training
m = prepare_pt2e(m, quantizer)
# run calibration
# calibrate(m, sample_inference_data)
m = convert_pt2e(m)