【大语言模型】ACL2024论文-07 BitDistiller: 释放亚4比特大型语言模型的潜力通过自蒸馏

发布于:2024-11-04 ⋅ 阅读:(62) ⋅ 点赞:(0)

【大语言模型】ACL2024论文-07 BitDistiller: 释放亚4比特大型语言模型的潜力通过自蒸馏


目录


BitDistiller: 释放亚4比特大型语言模型的潜力通过自蒸馏
在这里插入图片描述

摘要

本文介绍了BitDistiller,这是一个通过结合量化感知训练(QAT)和知识蒸馏(KD)来提升超低精度(亚4比特)大型语言模型(LLMs)性能的框架。BitDistiller首先采用定制的非对称量化和裁剪技术来尽可能保持量化权重的保真度,然后提出了一种新颖的基于置信度的Kullback-Leibler散度(CAKLD)目标,用于自蒸馏,以实现更快的收敛和更优的模型性能。实验评估表明,BitDistiller在3比特和2比特配置下,无论是在通用语言理解还是复杂推理基准测试中,都显著超越了现有方法。值得注意的是,BitDistiller更具成本效益,需要更少的数据和训练资源。

研究背景

随着大型语言模型(LLMs)规模的扩大,自然语言处理领域取得了令人印象深刻的进展。然而,这种模型规模的扩大在部署上带来了显著的挑战,尤其是在资源受限的设备上,因为它们需要大量的内存和计算能力。权重量化作为一种流行的策略,通过减少模型大小来提高LLMs的效率和可访问性,同时最小化性能损失。尽管4比特量化已被广泛采用,提供了显著的压缩比和保留LLM能力之间的平衡,但亚4比特量化会显著降低模型权重的保真度,尤其是在小型模型或需要复杂推理的任务中,导致模型性能恶化。
在这里插入图片描述

问题与挑战

在极端低比特QAT中实现高性能的两个基本挑战是:如何在量化过程中最大限度地保持权重保真度,以及如何在训练中有效学习低比特表示。

如何解决

BitDistiller通过以下方式解决上述挑战:

  1. 非对称量化和裁剪:BitDistiller采用了定制的非对称量化和裁剪策略,以保持全精度模型的能力,特别是在超低比特水平上。
  2. 自蒸馏:BitDistiller利用全精度模型作为教师,低比特模型作为学生,通过自蒸馏方法进行有效的低比特表示学习。
  3. CAKLD目标:BitDistiller创新性地提出了一种基于置信度的Kullback-Leibler散度(CAKLD)目标,优化知识传递效率,实现更快的收敛和增强的模型性能。

创新点

  • 非对称量化和裁剪:BitDistiller针对不同比特级别的量化采用了不同的量化策略,如NF格式和INT格式,以及非对称裁剪,以提高量化权重的表示保真度。
  • CAKLD目标:BitDistiller提出了一种新颖的CAKLD目标,它根据全精度模型对训练数据的置信度自动权衡模式寻求和模式覆盖行为。
  • 自蒸馏框架:BitDistiller将QAT与知识蒸馏相结合,使用全精度模型作为教师来指导低比特学生模型,这是一种简单而有效的自蒸馏方法。
    在这里插入图片描述

算法模型

BitDistiller的框架包括以下几个关键步骤:

  1. 非对称量化和裁剪:在QAT初始化阶段,BitDistiller对权重进行非对称裁剪,以减少量化误差。
  2. 自蒸馏:在训练过程中,全精度模型生成数据,低比特模型学习这些数据,通过CAKLD目标进行优化。
  3. CAKLD目标:CAKLD目标结合了反向KL散度和正向KL散度,根据全精度模型的置信度自动调整模式寻求和模式覆盖行为。
    在这里插入图片描述

实验效果

实验评估表明,BitDistiller在3比特和2比特配置下的性能显著优于现有的PTQ和QAT方法。以下是一些重要的数据和结论:

  • 语言建模任务:在WikiText-2的困惑度(PPL)和MMLU(5-shot)准确性方面,BitDistiller超越了竞争对手。
  • 推理任务:在HumanEval和GSM8K等推理基准测试中,BitDistiller在3比特和2比特量化中均展现出优越性能。
  • 成本效益:BitDistiller需要的训练数据和资源更少,更具成本效益。
    在这里插入图片描述
    在这里插入图片描述
    在这里插入图片描述
    在这里插入图片描述
    在这里插入图片描述

代码

https://github.com/DD-DuDa/BitDistiller.git
在这里插入图片描述

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function
from tqdm import tqdm
import gc
# import bitsandbytes as bnb
import torch.nn as nn
from functools import partial
# import bitsandbytes.functional as bnbF

class Round(Function):
    @staticmethod
    def forward(self, input):
        sign = torch.sign(input)
        output = sign * torch.floor(torch.abs(input) + 0.5)
        return output

    @staticmethod
    def backward(self, grad_output):
        grad_input = grad_output.clone()
        return grad_input

# core quantization method (simulated quantization)
def pseudo_quantize_tensor(w, n_bit=8,
                           zero_point=True, q_group_size=-1,
                           inplace=False,
                           get_scale_zp=False
                           ):
    org_w_shape = w.shape
    if q_group_size > 0:
        assert org_w_shape[-1] % q_group_size == 0
        w = w.reshape(-1, q_group_size)
    elif q_group_size == -1:
        w = w.reshape(-1, w.shape[-1])
    assert w.dim() == 2
    if zero_point:
        max_val = w.amax(dim=1, keepdim=True)
        min_val = w.amin(dim=1, keepdim=True)
        max_int = 2 ** n_bit - 1
        min_int = 0
        scales = (max_val - min_val).clamp(min=1e-5) / max_int
        zeros = (-torch.round(min_val / scales)).clamp_(min_int, max_int)
    else:  # we actually never used this
        assert min_val is None
        max_val = w.abs().amax(dim=1, keepdim=True)
        max_val = max_val.clamp(min=1e-5)
        max_int = 2 ** (n_bit - 1) - 1
        min_int = - 2 ** (n_bit - 1)
        scales = max_val / max_int
        zeros = 0

    assert torch.isnan(scales).sum() == 0
    assert torch.isnan(w).sum() == 0

    if inplace:
        ((w.div_(scales).round_().add_(zeros)).clamp_(
            min_int, max_int).sub_(zeros)).mul_(scales)
    else:
        w = (torch.clamp(torch.round(w / scales) +
                         zeros, min_int, max_int) - zeros) * scales
    assert torch.isnan(w).sum() == 0

    w = w.reshape(org_w_shape)

    if get_scale_zp:
        return w, scales.view(w.shape[0], -1), zeros.view(w.shape[0], -1)
    else:
        return w



@torch.no_grad()
def real_quantize_model_weight(
    model, w_bit, q_config,
    init_only=False
):
    from .qmodule import WQLinear
    from .pre_quant import get_blocks, get_named_linears, set_op_by_name
    assert q_config["zero_point"], "We only support zero_point quantization now."
    
    layers = get_blocks(model)
    for i in tqdm(range(len(layers)), desc="real weight quantization..." + ("(init only)" if init_only else "")):
        layer = layers[i]
        named_linears = get_named_linears(layer)
        # scale_activations(layer)

        for name, module in named_linears.items():
            if init_only:
                q_linear = WQLinear.from_linear(
                    module, w_bit, q_config['q_group_size'], True)
                q_linear.to(next(layer.parameters()).device)
                set_op_by_name(layer, name, q_linear)
            else:
                module.cuda()
                module.weight.data, scales, zeros = pseudo_quantize_tensor(module.weight.data, n_bit=w_bit, get_scale_zp=True, **q_config)
                # scales = scales.t().contiguous()
                # zeros = zeros.t().contiguous()
                q_linear = WQLinear.from_linear(
                    module, w_bit, q_config['q_group_size'], False, scales, zeros)
                module.cpu()
                q_linear.to(next(layer.parameters()).device)
                set_op_by_name(layer, name, q_linear)
                torch.cuda.empty_cache()
                gc.collect()
                
    torch.cuda.empty_cache()
    gc.collect()




def pseudo_quantize_n2f3_tensor(w, q_group_size=-1):
    quantizer = SteN2F3Quantizer(q_group_size=q_group_size)
    w = quantizer(w)
    return w


class SteInt3AsymQuantizer(nn.Module):
    def __init__(self, q_group_size=128):
        super().__init__()
        self.q_group_size = q_group_size
        self.bit = 3
    def forward(self, x):
        org_w_shape = x.shape

        if self.q_group_size > 0:
            assert org_w_shape[-1] % self.q_group_size == 0
            x = x.reshape(-1, self.q_group_size)
        elif self.q_group_size == -1:
            assert org_w_shape[-1] % self.q_group_size == 0
            x = x.reshape(-1, x.shape[-1])
        assert x.dim() == 2

        max_val = x.amax(dim=1, keepdim=True)
        min_val = x.amin(dim=1, keepdim=True)
        max_int = 2 ** self.bit - 1
        min_int = 0
        scales = (max_val - min_val).clamp(min=1e-5) / max_int
        zeros = (-torch.round(min_val / scales)).clamp_(min_int, max_int)

        assert torch.isnan(scales).sum() == 0
        assert torch.isnan(x).sum() == 0

        x = (torch.clamp(Round.apply(x / scales) +
                         zeros, min_int, max_int) - zeros) * scales
        assert torch.isnan(x).sum() == 0

        x = x.reshape(org_w_shape)

        return x

class SteInt2AsymQuantizer(nn.Module):
    def __init__(self, q_group_size=64):
        super().__init__()
        self.q_group_size = q_group_size
        self.bit = 2
    def forward(self, x):
        org_w_shape = x.shape

        if self.q_group_size > 0:
            assert org_w_shape[-1] % self.q_group_size == 0
            x = x.reshape(-1, self.q_group_size)
        assert x.dim() == 2

        max_val = x.amax(dim=1, keepdim=True)
        min_val = x.amin(dim=1, keepdim=True)
        max_int = 2 ** self.bit - 1
        min_int = 0
        scales = (max_val - min_val).clamp(min=1e-5) / max_int
        zeros = (-torch.round(min_val / scales)).clamp_(min_int, max_int)

        assert torch.isnan(scales).sum() == 0
        assert torch.isnan(x).sum() == 0

        x = (torch.clamp(Round.apply(x / scales) +
                         zeros, min_int, max_int) - zeros) * scales
        assert torch.isnan(x).sum() == 0

        x = x.reshape(org_w_shape)

        return x

class SteN2F3Quantizer(nn.Module):
    def __init__(self, q_group_size=128):
        super().__init__()
        self.q_group_size = q_group_size
    
    def forward(self, x):
        org_w_shape = x.shape

        # reshape to groupsize
        if self.q_group_size > 0:
            assert org_w_shape[-1] % self.q_group_size == 0
            qx = x.reshape(-1, self.q_group_size)
        elif self.q_group_size == -1:
            qx = x.reshape(-1, x.shape[-1])
        assert qx.dim() == 2

        # Get the Min Max
        max_val = qx.amax(dim=1, keepdim=True)
        min_val = qx.amin(dim=1, keepdim=True)

        
        scale_pos = torch.abs(max_val)
        scale_neg = torch.abs(min_val)

        dev = qx.device
        x_pos = torch.zeros_like(qx)
        x_neg = torch.zeros_like(qx)
        x_pos = torch.where(qx >= 0, qx, x_pos)
        x_neg = torch.where(qx < 0, qx, x_neg)
        q_pos = x_pos / scale_pos
        q_neg = x_neg / scale_neg

        q_pos, q_neg = self.round_pass(q_pos, q_neg, dev)

        qx = q_pos * scale_pos + q_neg * scale_neg

        qx = qx.reshape(org_w_shape)

        return qx
    
    def round_n2f3(self, q_pos, q_neg, dev):
        q_pos = torch.where(q_pos >= 0.8114928305149078,                                        torch.tensor(1.0).to(dev), q_pos)
        q_pos = torch.where((q_pos < 0.8114928305149078)    & (q_pos >= 0.5024898052215576),    torch.tensor(0.6229856610298157).to(dev), q_pos)
        q_pos = torch.where((q_pos < 0.5024898052215576)    & (q_pos >= 0.2826657369732857),    torch.tensor(0.3819939494132996).to(dev), q_pos)
        q_pos = torch.where((q_pos < 0.2826657369732857)    & (q_pos >= 0.0916687622666359),    torch.tensor(0.1833375245332718).to(dev), q_pos)
        q_pos = torch.where(q_pos < 0.0916687622666359,                                        torch.tensor(0).to(dev), q_pos)

        q_neg = torch.where(q_neg >= -0.1234657019376755,                                     torch.tensor(0).to(dev), q_neg)
        q_neg = torch.where((q_neg < -0.1234657019376755)   & (q_neg >= -0.39097706973552704),   torch.tensor(-0.2469314038753510).to(dev), q_neg)
        q_neg = torch.where((q_neg < -0.39097706973552704)   & (q_neg >= -0.7675113677978516),   torch.tensor(-0.5350227355957031).to(dev), q_neg)
        q_neg = torch.where(q_neg < -0.7675113677978516,                                        torch.tensor(-1.0).to(dev), q_neg)

        return q_pos, q_neg

    def round_pass(self, q_pos, q_neg, dev):
        y_grad_pos, y_grad_neg = q_pos, q_neg
        y_pos, y_neg = self.round_n2f3(q_pos, q_neg, dev)
        
        return (y_pos - y_grad_pos).detach() + y_grad_pos, (y_neg - y_grad_neg).detach() + y_grad_neg

推荐阅读指数:✭✭✭✭✩

推荐理由

  • 创新性:BitDistiller通过结合QAT和KD,在亚4比特量化领域提供了一种新的解决方案,具有显著的性能提升。
  • 实用性:BitDistiller不仅在理论上具有创新性,而且在实际应用中也显示出了成本效益,这对于资源受限的设备尤为重要。
  • 广泛适用性:BitDistiller在多种语言和推理任务中都展现出了优越的性能,表明其方法的广泛适用性。

后记

如果您对我的博客内容感兴趣,欢迎三连击(点赞、收藏、关注和评论),我将持续为您带来计算机人工智能前沿技术(尤其是AI相关的大语言模型,深度学习和计算机视觉相关方向)最新学术论文及工程实践方面的内容分享,助力您更快更准更系统地了解 AI前沿技术