[LLM Study Notes | Quantization] PyTorch Quantization Basics (1)

Published: 2025-06-27

PyTorch Quantization

[!note]

  • Official definition: performing computations and storing tensors at lower bitwidths than floating point precision.
  • Supports INT8 quantization, which cuts model size and memory requirements by roughly 4x and speeds up inference by 2-4x.
  • In plain terms: lower the precision of weights and activations (FP32 → INT8), thereby reducing model size and memory requirements.

1. Prerequisites

1.1 Operator Fusion

Operator fusion merges the computations of several consecutive layers into a single composite operator, reducing the number of round trips to memory.

e.g. fusing Conv → BN → ReLU into a single ConvBnReLU operator

| Pipeline | Memory accesses | Arithmetic intensity |
| --- | --- | --- |
| Unfused (3 operators) | 6 | Lower |
| Fused (1 operator) | 2 | Higher |

On an NVIDIA GPU, this looks roughly like:

// Unfused: three kernel launches, each round-tripping intermediates through global memory
conv_kernel<<<...>>>(input, weight, temp1);
bias_kernel<<<...>>>(temp1, bias, temp2);
relu_kernel<<<...>>>(temp2, output);

// Fused: a single kernel keeps the intermediate value in registers (schematic pseudocode)
__global__ void fused_conv_bias_relu(const float* input, const float* weight,
                                     const float* bias, float* output) {
    float val = conv2d(input, weight);  // per-element convolution result (pseudocode)
    val += *bias;
    *output = fmaxf(val, 0.0f);         // ReLU applied in registers; one write back to memory
}

2. Quantization Basics

2.1 Symmetric & Asymmetric Quantization

⚙️ Differences

  • Symmetric Quantization

$$X_{int}=\mathrm{round}\left(\frac{X_{float}}{scale}\right),\qquad scale=\frac{\max(|X|)}{2^{n-1}-1}$$

  • Asymmetric (Affine) Quantization

$$X_{int}=\mathrm{round}\left(\frac{X_{float}}{scale}\right)+zero\_point,\qquad scale=\frac{\max(x)-\min(x)}{2^{n}-1}$$

$$zero\_point=\mathrm{round}\left(\frac{-\min(x)}{scale}\right)$$

| Property | Symmetric Quantization | Asymmetric (Affine) Quantization |
| --- | --- | --- |
| Zero point | Fixed at 0 | Computed dynamically (zero_point) |
| Value range | [-127, 127] (int8) | [0, 255] (uint8) |
| Compute overhead | Lower (no zero_point handling) | Higher |
| Accuracy loss | Sensitive to skewed distributions | More robust; handles skewed distributions better |
| Typical use | Weights (roughly balanced around 0) | Activations |
| Hardware support | Widely supported (e.g., GPU/TPU) | Needs extra zero_point handling |
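The two formulas above can be reproduced in a few lines of plain PyTorch. This is a minimal sketch; the helper names and the toy tensor are mine, not a PyTorch API:

```python
import torch

def symmetric_quantize(x: torch.Tensor, n_bits: int = 8):
    # scale = max(|x|) / (2^(n-1) - 1), zero point fixed at 0
    qmax = 2 ** (n_bits - 1) - 1                           # 127 for int8
    scale = x.abs().max() / qmax
    x_int = torch.clamp(torch.round(x / scale), -qmax, qmax).to(torch.int8)
    return x_int, scale

def affine_quantize(x: torch.Tensor, n_bits: int = 8):
    # scale = (max(x) - min(x)) / (2^n - 1), zero point shifts min(x) onto 0
    qmax = 2 ** n_bits - 1                                 # 255 for uint8
    scale = (x.max() - x.min()) / qmax
    zero_point = torch.round(-x.min() / scale)
    x_int = torch.clamp(torch.round(x / scale) + zero_point, 0, qmax).to(torch.uint8)
    return x_int, scale, zero_point

x = torch.randn(16)                                        # toy tensor
x_sym, s = symmetric_quantize(x)
x_aff, s2, zp = affine_quantize(x)
x_dequant = (x_aff.float() - zp) * s2                      # quantization error = x - x_dequant
```

The round-and-clip step is exactly where the quantization error comes from; the dequantized tensor is the value an INT8 kernel effectively "sees".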

🤖 From an engineering standpoint: why PTQ usually quantizes activations asymmetrically while QAT uses symmetric quantization

| Mode | Recommended default | Rationale |
| --- | --- | --- |
| PTQ | Weights: symmetric; activations: asymmetric | Activations are calibrated statically, without training, so asymmetric quantization better fits their non-negative distributions (e.g., post-ReLU) |
| QAT | Weights: symmetric; activations: symmetric (by convention) | Activations are inside the training loop, so training can push them toward a symmetric range and the accuracy loss stays more controllable |
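In PyTorch this split is expressed through a `QConfig` that pairs a weight observer with an activation observer. The configuration below is a hedged example of the "symmetric weights, affine activations" default, not the only valid choice:

```python
import torch
from torch.ao.quantization import QConfig, MinMaxObserver, HistogramObserver

# Weights: per-tensor symmetric int8; activations: per-tensor affine uint8
example_qconfig = QConfig(
    activation=HistogramObserver.with_args(
        dtype=torch.quint8, qscheme=torch.per_tensor_affine),
    weight=MinMaxObserver.with_args(
        dtype=torch.qint8, qscheme=torch.per_tensor_symmetric),
)
```

The built-in defaults (e.g., `get_default_qconfig('x86')`) follow the same split in spirit, though they use per-channel weight observers and backend-specific details.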
2.2 PTQ & QAT

[!note]

PTQ quantizes the parameters of an already trained model directly, so it is well suited to fast deployment; QAT inserts fake-quantization nodes so that quantization error is simulated during training, achieving higher accuracy at the cost of retraining.

⚙️ Differences

| Property | PTQ (post-training quantization) | QAT (quantization-aware training) |
| --- | --- | --- |
| Training stage | FP32 training only | Training with fake-quant nodes inserted |
| Backpropagation | ❌ Not involved | ✅ Supported via STE |
| Accuracy loss | Larger (especially for small models) | Usually smaller |
| Compute cost | Low (calibration only) | High (full training run) |
| Typical use | Fast deployment | Accuracy-critical scenarios |

[!tip]

QAT fake-quantization nodes

  • Purpose: simulate quantization error during training. Weights and activations remain FP32 in every layer, but during each forward pass the values are "quantized and then dequantized", mimicking the effect of quantization.
  • Because the quantization step uses a round function, which is non-differentiable, the FakeQuant module relies on the Straight-Through Estimator (STE) to approximate the gradient (see the sketch below).
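A minimal sketch of what such a FakeQuant node does, written as a custom autograd function to illustrate STE; this is not PyTorch's own FakeQuantize implementation:

```python
import torch

class FakeQuantSTE(torch.autograd.Function):
    """Quantize-then-dequantize in forward; identity gradient (STE) in backward."""

    @staticmethod
    def forward(ctx, x, scale, zero_point, qmin, qmax):
        q = torch.clamp(torch.round(x / scale) + zero_point, qmin, qmax)
        return (q - zero_point) * scale    # dequantized value carries the quantization error

    @staticmethod
    def backward(ctx, grad_output):
        # round() has zero gradient almost everywhere, so STE treats it as identity
        return grad_output, None, None, None, None

x = torch.randn(8, requires_grad=True)
y = FakeQuantSTE.apply(x, torch.tensor(0.1), torch.tensor(0), -128, 127)
y.sum().backward()                          # gradient flows straight through: x.grad is all ones
```

During QAT the forward pass therefore sees quantization error, while the backward pass updates the FP32 weights as if the rounding were not there.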

3. Three Ways to Implement Quantization in PyTorch

Reference: Quantization — PyTorch 2.7 documentation

| Property | Eager Mode QAT | FX Graph QAT | Export QAT |
| --- | --- | --- | --- |
| Approach | Eager (dynamic graph) | Symbolic rewriting (torch.fx) | Compiler-level (torch.export) |
| Control-flow support | ✅ (no tracing needed) | ❌ (model must be symbolically traceable) | Limited (torch.export constraints) |
| Operator fusion | ❌ (manual fuse_modules only) | ✅ automatic | ✅ automatic 🌟 |
| Typical API | prepare_qat | prepare_fx / prepare_qat_fx | export_for_training + prepare_pt2e |
| Quantizable units | Modules only | Modules & functions | Modules & functions |

[!note]

Whether you use PTQ or QAT, every implementation path follows the same two steps: a prepare call that inserts observers / fake-quant modules, and a convert call that swaps in quantized kernels (shown here with the FX APIs prepare_fx and convert_fx).

model_prepared = quantize_fx.prepare_fx(model, qconfig_mapping, example_inputs)  # insert observers / fake-quant nodes
model_quantized = quantize_fx.convert_fx(model_prepared)                         # replace modules with quantized kernels

🎯 Core function: at every quantization point specified by qconfig_mapping (e.g., Conv2d, Linear), insert the corresponding observer or fake_quant node.

📦 Two kinds of modules get inserted:

| Module | Used for | Description |
| --- | --- | --- |
| Observer | PTQ: records min/max statistics | Used during calibration to compute scale and zero_point |
| FakeQuantize | QAT | Simulates quantization error while preserving gradient flow, so training can continue |
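As a standalone illustration of what an observer collects (assuming the built-in MinMaxObserver; the random data below is just a stand-in for real calibration batches):

```python
import torch
from torch.ao.quantization.observer import MinMaxObserver

obs = MinMaxObserver(dtype=torch.quint8, qscheme=torch.per_tensor_affine)
for _ in range(10):                  # "calibration": feed representative data through the observer
    obs(torch.randn(32))
scale, zero_point = obs.calculate_qparams()
print(scale, zero_point)             # the qparams that convert_fx bakes into the quantized module
```

The convert step later reads these qparams to replace the observed FP32 modules with their quantized counterparts.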
3.1 Eager Mode Quantization
import torch

# define a floating point model where some layers could benefit from QAT
class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # QuantStub converts tensors from floating point to quantized
        self.quant = torch.ao.quantization.QuantStub()
        self.conv = torch.nn.Conv2d(1, 1, 1)
        self.bn = torch.nn.BatchNorm2d(1)
        self.relu = torch.nn.ReLU()
        # DeQuantStub converts tensors from quantized to floating point
        self.dequant = torch.ao.quantization.DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        x = self.dequant(x)
        return x

# create a model instance
model_fp32 = M()

# model must be set to eval for fusion to work
model_fp32.eval()

# attach a global qconfig, which contains information about what kind
# of observers to attach. Use 'x86' for server inference and 'qnnpack'
# for mobile inference. Other quantization configurations such as selecting
# symmetric or asymmetric quantization and MinMax or L2Norm calibration techniques
# can be specified here.
# Note: the old 'fbgemm' is still available but 'x86' is the recommended default
# for server inference.
# model_fp32.qconfig = torch.ao.quantization.get_default_qconfig('fbgemm')
model_fp32.qconfig = torch.ao.quantization.get_default_qat_qconfig('x86')

# fuse the activations to preceding layers, where applicable
# this needs to be done manually depending on the model architecture
model_fp32_fused = torch.ao.quantization.fuse_modules(model_fp32,
    [['conv', 'bn', 'relu']])

# Prepare the model for QAT. This inserts observers and fake_quants in
# the model that will observe weight and activation tensors during calibration.
# The model needs to be set to train mode for the QAT logic to work.
model_fp32_prepared = torch.ao.quantization.prepare_qat(model_fp32_fused.train())

# run the training loop (not shown)
training_loop(model_fp32_prepared)

# Convert the observed model to a quantized model. This does several things:
# quantizes the weights, computes and stores the scale and bias value to be
# used with each activation tensor, fuses modules where appropriate,
# and replaces key operators with quantized implementations.
model_fp32_prepared.eval()
model_int8 = torch.ao.quantization.convert(model_fp32_prepared)

# run the model, relevant calculations will happen in int8
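# (input_fp32 is assumed here: any FP32 tensor matching the model, e.g. torch.randn(4, 1, 28, 28))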
res = model_int8(input_fp32)
3.2 FX Graph Mode Quantization (maintenance mode)
import torch
from torch.ao.quantization import (
  get_default_qconfig_mapping,
  get_default_qat_qconfig_mapping,
  QConfigMapping,
)
import torch.ao.quantization.quantize_fx as quantize_fx
import copy

model_fp = UserModel()

#
# post training dynamic/weight_only quantization
#

# we need to deepcopy if we still want to keep model_fp unchanged after quantization since quantization apis change the input model
model_to_quantize = copy.deepcopy(model_fp)
model_to_quantize.eval()
qconfig_mapping = QConfigMapping().set_global(torch.ao.quantization.default_dynamic_qconfig)
# a tuple of one or more example inputs are needed to trace the model
example_inputs = (input_fp32,)
# prepare
model_prepared = quantize_fx.prepare_fx(model_to_quantize, qconfig_mapping, example_inputs)
# no calibration needed when we only have dynamic/weight_only quantization
# quantize
model_quantized = quantize_fx.convert_fx(model_prepared)

#
# post training static quantization
#

model_to_quantize = copy.deepcopy(model_fp)
qconfig_mapping = get_default_qconfig_mapping("qnnpack")
model_to_quantize.eval()
# prepare
model_prepared = quantize_fx.prepare_fx(model_to_quantize, qconfig_mapping, example_inputs)
# calibrate (not shown)
# quantize
model_quantized = quantize_fx.convert_fx(model_prepared)

#
# quantization aware training for static quantization
#

model_to_quantize = copy.deepcopy(model_fp)
qconfig_mapping = get_default_qat_qconfig_mapping("qnnpack")
model_to_quantize.train()
# prepare
model_prepared = quantize_fx.prepare_qat_fx(model_to_quantize, qconfig_mapping, example_inputs)
# training loop (not shown)
# quantize
model_quantized = quantize_fx.convert_fx(model_prepared)

#
# fusion
#
model_to_quantize = copy.deepcopy(model_fp)
model_fused = quantize_fx.fuse_fx(model_to_quantize)
3.3 PyTorch 2 Export Quantization
import torch
from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e
from torch.export import export_for_training
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(5, 10)

    def forward(self, x):
        return self.linear(x)

# initialize a floating point model
float_model = M().eval()

# define calibration function
def calibrate(model, data_loader):
    model.eval()
    with torch.no_grad():
        for image, target in data_loader:
            model(image)

# Step 1. program capture
# NOTE: this API will be updated to torch.export API in the future, but the captured
# result should mostly stay the same
example_inputs = (torch.randn(1, 5),)  # example input matching Linear(5, 10)
m = export_for_training(float_model, example_inputs).module()
# we get a model with aten ops

# Step 2. quantization
# backend developer will write their own Quantizer and expose methods to allow
# users to express how they
# want the model to be quantized
quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config())
# or prepare_qat_pt2e for Quantization Aware Training
m = prepare_pt2e(m, quantizer)

# run calibration
# calibrate(m, sample_inference_data)
m = convert_pt2e(m)
