SmolVLM2: The Smollest Video Model Ever(五)

发布于:2025-04-21 ⋅ 阅读:(78) ⋅ 点赞:(0)

https://huggingface.co/HuggingFaceTB/SmolLM-135M-Instruct

继续学习SmolLM

模型概述

SmolLM 是一系列小型语言模型,有三种规模:参数数量分别为 1.35 亿、3.6 亿和 17 亿。

这些模型在 SmolLM 语料库上进行训练,该语料库是经过精心整理的高质量教育及合成数据集合,专为训练大语言模型而设计。更多详细信息,请参阅我们的博客文章。

为构建 SmolLM-Instruct,我们在公开可用的数据集上对基础模型进行了微调。

变更日志

版本发布 描述
v0.1 SmolLM-Instruct 的首次发布。我们在 WebInstructSub 数据集的允许使用子集上进行微调,并结合了 StarCoder2-Self-OSS-Instruct。然后,对于 1.35 亿参数和 17 亿参数的模型,在 HelpSteer 上进行了一个周期的直接偏好优化(DPO);对于 3.6 亿参数的模型,则在 argilla/dpo-mix-7k 上进行了直接偏好优化。
v0.2 我们将微调数据组合更改为更适合小型模型的数据集。我们在由 llama3.1-70B 生成的包含 2000 个简单日常对话的新数据集(everyday-conversations-llama3.1-2k)、Magpie-Pro-300K-Filtered、StarCoder2-Self-OSS-Instruct 以及 OpenHermes-2.5 的一小部分子集上进行训练。
v0.2 版本的模型在紧扣主题以及对标准提示(如问候语和关于其作为人工智能助手角色的问题)做出恰当回应方面表现更出色。在 AlpacaEval 评估中,SmolLM-360M-Instruct(v0.2)相较于 SmolLM-360M-Instruct(v0.1)的胜率为 63.3%。你可以在此处找到详细信息。

你可以在 transformers 代码中通过指定 revision="v0.1" 来加载 v0.1 版本的模型:

model = AutoModelForCausalLM.from_pretrained("HuggingFaceTB/SmolLM-135M-Instruct", revision="v0.1")

用法

本地应用

⚡ 对于本地应用,除了在这个集合中快速的浏览器演示之外(https://huggingface.co/collections/HuggingFaceTB/local-smollms-66c0f3b2a15b4eed7fb198d0),你还可以找到 MLC、GGUF 和 Transformers.js 格式的优化模型实现。

我们注意到,4 位量化会降低 1.35 亿参数和 3.6 亿参数模型的质量,因此对于 MLC,我们使用 q016 量化,对于 WebGPU 演示,则使用 ONNX/Transformers.js 检查点。我们还建议使用温度 0.2 和核采样参数 top-p 为 0.9。

Transformers

安装 transformers:

bash

pip install transformers

python

# pip install transformers
from transformers import AutoModelForCausalLM, AutoTokenizer
checkpoint = "HuggingFaceTB/SmolLM-135M-Instruct"

device = "cuda" # 使用 GPU 时设置为 "cuda",使用 CPU 时设置为 "cpu"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
# 对于多 GPU 环境,安装 accelerate 并使用 `model = AutoModelForCausalLM.from_pretrained(checkpoint, device_map="auto")`
model = AutoModelForCausalLM.from_pretrained(checkpoint).to(device)

messages = [{"role": "user", "content": "What is the capital of France."}]
input_text=tokenizer.apply_chat_template(messages, tokenize=False)
print(input_text)
inputs = tokenizer.encode(input_text, return_tensors="pt").to(device)
outputs = model.generate(inputs, max_new_tokens=50, temperature=0.2, top_p=0.9, do_sample=True)
print(tokenizer.decode(outputs[0]))
使用 TRL 进行聊天

你也可以使用 TRL 命令行界面在终端中与模型进行聊天:

pip install trl
trl chat --model_name_or_path HuggingFaceTB/SmolLM-135M-Instruct --device cpu

局限性

此外,生成的内容可能并不总是在事实上准确、逻辑上一致,或者没有训练数据中存在的偏差。我们建议用户将其用作辅助工具,而不是作为确定的信息来源。我们发现,这些模型可以处理常识性知识问题、创意写作和基本的 Python 编程。但它们仅支持英语,并且在处理算术、编辑任务和复杂推理方面可能存在困难。有关这些模型能力的更多详细信息,请参阅我们的博客文章。

训练参数

我们使用对齐手册,在变更日志中提到的数据集上训练模型,v0.2 版本使用以下参数(其中大多数参数来自 Zephyr Gemma 的训练方案):

  • 训练 1 个周期
  • 学习率为 1e-3
  • 余弦退火学习率调度
  • 热身比例为 0.1
  • 全局批量大小为 262k 个词元

你可以在此处找到训练方案:https://github.com/huggingface/alignment-handbook/tree/smollm/recipes/smollm

引用

plaintext

@misc{allal2024SmolLM,
      title={SmolLM - blazingly fast and remarkably powerful}, 
      author={Loubna Ben Allal and Anton Lozhkov and Elie Bakouch and Leandro von Werra and Thomas Wolf},
      year={2024},
}

代码添加与更改

config.json

{
  "_name_or_path": "HuggingFaceTB/SmolLM-135M",
  "architectures": [
    "LlamaForCausalLM"
  ],
  "attention_bias": false,
  "attention_dropout": 0.0,
  "bos_token_id": 1,
  "eos_token_id": 2,
  "hidden_act": "silu",
  "hidden_size": 576,
  "initializer_range": 0.02,
  "intermediate_size": 1536,
  "max_position_embeddings": 2048,
  "mlp_bias": false,
  "model_type": "llama",
  "num_attention_heads": 9,
  "num_hidden_layers": 30,
  "num_key_value_heads": 3,
  "pad_token_id": 2,
  "pretraining_tp": 1,
  "rms_norm_eps": 1e-05,
  "rope_scaling": null,
  "rope_theta": 10000.0,
  "tie_word_embeddings": true,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.42.3",
  "use_cache": true,
  "vocab_size": 49152
}

添加model代码

import torch
from llmc.utils.registry_factory import MODEL_REGISTRY
from .base_model import BaseModel
from transformers import AutoConfig, SmolVLMForConditionalGeneration
from loguru import logger
from accelerate import Accelerator, DistributedType
from typing import Optional, Union
from transformers.models.llama.modeling_llama import LlamaRMSNorm
# from .smolvlm_model import SmolVLMAutoModelForCausalLM
from llmc.compression.quantization.module_utils import (
    _LLMC_LINEAR_TYPES_, _LLMC_LN_TYPES_, _TRANSFORMERS_LINEAR_TYPES_,
    _TRANSFORMERS_LN_TYPES_, LlmcFp8Linear)

@MODEL_REGISTRY
class SmolVLM2(BaseModel):
    def __init__(self, config, device_map=None, use_cache=False):
        super().__init__(config, device_map, use_cache)
        self.vision_prefix = "model.vision_model"
        self.text_prefix = "model.text_model"

        self._init_modality_specific_params()
        # 添加兼容性属性
        self.linear_blocks = []  # 用于兼容旧式索引访问
        self.block_modality_map = {}  # 记录块所属模态

    def _init_modality_specific_params(self):
        """初始化多模态专用参数"""
        self.blocks = {
            "vision": [],
            "text": []
        }
        self.vision_embeds = []
        self.text_embeds = []
        self.block_name_prefix = {}
        self.pairs = {}

    def build_model(self):
        self.model_config = AutoConfig.from_pretrained(
            self.model_path,
            trust_remote_code=True,  # 必须启用
            model_type="smolvlm",  # 显式指定类型
            torch_dtype=torch.bfloat16  # 强制指定配置类型
        )
        # 使用自定义加载器
        self.model = SmolVLMForConditionalGeneration.from_pretrained(
            self.model_path,
            config=self.model_config,
            device_map=self.device_map,
            trust_remote_code=True,  # 关键参数
            torch_dtype=torch.bfloat16,  # 统一加载类型
            low_cpu_mem_usage=True,
        )
        # smol_VLMForConditionalGeneration=self.model
        # self.model=self.model.model
        # 修正lm_head数据类型
        if self.model.lm_head.weight.dtype != torch.bfloat16:
            self.model.lm_head = self.model.lm_head.to(torch.bfloat16)
        logger.info(f"lm_head dtype: {self.model.lm_head.weight.dtype}")
        # 初始化组件引用
        self.vision_model = self.model.model.vision_model
        self.text_model = self.model.model.text_model
        self.connector = self.model.model.connector
        # 验证类型一致性
        text_emb = self.text_model.embed_tokens
        assert text_emb.weight.dtype == torch.bfloat16, "文本嵌入层类型错误"
        assert self.model.lm_head.weight.dtype == torch.bfloat16, "输出头类型错误"


        # 统一设备初始化
        # self._sync_device()
    def find_blocks(self):
        # 文本模型的块(LlamaDecoderLayer)作为主要处理块
        self.blocks = self.text_model.layers
        # 视觉模型的块单独存储(可选,根据需求)
        self.vision_blocks = self.vision_model.encoder.layers

    def find_embed_layers(self):
        # 视觉嵌入层:patch embedding( Conv2d)和位置嵌入(Embedding)
        self.vision_patch_embed = self.vision_model.embeddings.patch_embedding
        self.vision_pos_embed = self.vision_model.embeddings.position_embedding
        # 文本嵌入层
        self.text_embed_tokens = self.text_model.embed_tokens

    def get_embed_layers(self):
        # 返回所有嵌入层(视觉和文本)
        return [self.vision_patch_embed, self.vision_pos_embed, self.text_embed_tokens]

    def get_head_layers(self):
        # 生成头
        return [self.model.lm_head]

    def get_pre_head_layernorm_layers(self):
        # 文本模型的最终层归一化
        return [self.text_model.norm]

    def get_layers_except_blocks(self):
        # 除块外的层:视觉嵌入、视觉后归一化、文本嵌入、文本最终归一化、生成头
        return [
            self.vision_patch_embed,
            self.vision_pos_embed,
            self.vision_model.post_layernorm,
            self.text_embed_tokens,
            self.text_model.norm,
            self.model.lm_head
        ]

    def skip_layer_name(self):
        # 跳过生成头(与原始LLaMA逻辑一致)
        return ['lm_head']

    def has_bias(self):
        # 视觉模块的线性层有偏置(q_proj/k_proj/v_proj/out_proj均bias=True),文本模块无偏置
        return True

    def get_layernorms_in_block(self, block):
        # 处理文本块的层归一化(与LLaMA一致)
        return {
            'input_layernorm': block.input_layernorm,
            'post_attention_layernorm': block.post_attention_layernorm,
        }

    def get_subsets_in_block(self, block):
        # 文本块的子集结构(与LLaMA一致)
        return [
            {
                'layers': {
                    'self_attn.q_proj': block.self_attn.q_proj,
                    'self_attn.k_proj': block.self_attn.k_proj,
                    'self_attn.v_proj': block.self_attn.v_proj,
                },
                'prev_op': [block.input_layernorm],
                'input': ['self_attn.q_proj'],
                'inspect': block.self_attn,
                'has_kwargs': True,
            },
            {
                'layers': {'self_attn.o_proj': block.self_attn.o_proj},
                'prev_op': [block.self_attn.v_proj],
                'input': ['self_attn.o_proj'],
                'inspect': block.self_attn.o_proj,
                'has_kwargs': False,
            },
            {
                'layers': {
                    'mlp.gate_proj': block.mlp.gate_proj,
                    'mlp.up_proj': block.mlp.up_proj,
                },
                'prev_op': [block.post_attention_layernorm],
                'input': ['mlp.gate_proj'],
                'inspect': block.mlp,
                'has_kwargs': False,
                'is_mlp': True,
            },
            {
                'layers': {'mlp.down_proj': block.mlp.down_proj},
                'prev_op': [block.mlp.up_proj],
                'input': ['mlp.down_proj'],
                'inspect': block.mlp.down_proj,
                'has_kwargs': False,
                'is_mlp': True,
            },
        ]

    # 以下为可选扩展(若需处理视觉块,可添加额外方法,但BaseModel未强制要求)
    def find_block_name(self):
        # 定义文本块的命名前缀(与LLaMA一致)
        self.block_name_prefix = 'text_model.layers'
        self.pairs = {'q_proj': 'qkv', 'o_proj': 'out', 'up_proj': 'fc1'}
    # 保持与BaseModel接口兼容的其他方法(如需可补充视觉处理逻辑)

配置新的SmolVLM2

from .bloom import Bloom
from .chatglm import ChatGLM
from .deepseekv2 import DeepseekV2
from .deepseekv3 import DeepseekV3
from .falcon import Falcon
from .gemma2 import Gemma2
from .glm4v import GLM4V
from .internlm2 import InternLM2
from .internomni import InternOmni
from .internvl2 import InternVL2
from .llama import Llama
from .llava import Llava
from .minicpm import MiniCPM
from .minicpmv import MiniCPMV
from .mistral import Mistral
from .mixtral import Mixtral
from .mllama import Mllama
from .opt import Opt
from .phi import Phi
from .phi3 import Phi3
from .qwen import Qwen
from .qwen2 import Qwen2
from .qwen2audio import Qwen2Audio
from .qwen2moe import Qwen2Moe
from .qwen2vl import Qwen2VL
from .smollm import SmolLM
from .smolvlm2 import SmolVLM2
from .stablelm import StableLm
from .starcoder import Starcoder
from .vila import Vila
from .vit import Vit

量化配置文件

base:
    seed: &seed 42
model:
    type: SmolVLM2 #【SmolLM,SmolVLM2】
    path: /mnt/share/toky/LLMs/SmolVLM2-2.2B-Instruct/ #【/mnt/share/toky/LLMs/SmolVLM2-2.2B-Instruct/,/mnt/share/toky/LLMs/SmolLM-135M-Instruct/】
    tokenizer_mode: slow
    torch_dtype: auto
calib:
    name: pileval
    download: False
    path: /mnt/share/toky/Datasets/LLMC/pileval/
    n_samples: 128
    bs: -1
    seq_len: 512
    preproc: pileval_awq
    seed: *seed
eval:
    eval_pos: [pretrain, transformed, fake_quant]
    name: wikitext2
    download: False
    path: /mnt/share/toky/Datasets/LLMC/wikitext2/
    seq_len: 2048
    # For 7B / 13B model eval, bs can be set to "1", and inference_per_block can be set to "False".
    # For 70B model eval, bs can be set to "20", and inference_per_block can be set to "True".
    bs: 1
    inference_per_block: False
quant:
    vision:
        method: Awq
        weight:
            bit: 4
            symmetric: True
            granularity: per_group
            group_size: 16
        special:
            trans: True
            # The options for "trans_version" include "v1" and "v2".
            # But their results don't differ significantly.
            trans_version: v2
            weight_clip: True
            # For 2-bit quantization, setting "clip_sym: False" will yield better results.
            clip_sym: True
    language:
        method: Awq
        weight:
            bit: 4
            symmetric: True
            granularity: per_group
            group_size: 128
        special:
            trans: True
            # The options for "trans_version" include "v1" and "v2".
            # But their results don't differ significantly.
            trans_version: v2
            weight_clip: True
            # For 2-bit quantization, setting "clip_sym: False" will yield better results.
            clip_sym: True
save:
    save_trans: False
    save_fake: False
    save_vllm: False
    save_path: /mnt/share/toky/Projects/LLMC_Test/llmc_quantized/SmolVLM2

修改了base_blockwise_quantization.py

import copy
import functools
import gc
import json
import os
import re
from collections import defaultdict
from functools import partial

import torch
import torch.distributed as dist
import torch.nn as nn
from loguru import logger

from llmc.utils.registry_factory import KV_REGISTRY, TOKEN_REDUCTION_REGISTRY

from ..blockwise_optimization import BlockwiseOpt
from .attn_utils import _LLMC_ATTN_MAP_
from .auto_clip import AutoClipper
from .utils import is_fp8_supported_gpu

if is_fp8_supported_gpu():
    from .kernel import weight_cast_to_bf16, weight_cast_to_fp8
    logger.info('import kernel successful.')
else:
    from .quant import weight_cast_to_bf16, weight_cast_to_fp8
    logger.info('import quant successful.')

from .hadamard_utils import apply_exact_had_to_linear, get_hadK
from .module_utils import (_LLMC_LINEAR_TYPES_, _LLMC_LN_TYPES_,
                           _REALQUANT_LINEAR_MAP_, _TRANSFORMERS_LINEAR_TYPES_,
                           _TRANSFORMERS_LN_TYPES_, EffcientFakeQuantLinear,
                           FakeQuantLinear, LlmcActFn, OriginFloatLinear,
                           RotateLinear)
from .quant import FloatQuantizer, IntegerQuantizer, Weight48IntegerQuantizer
from .utils import check_do_quant, check_w_only, get_aquantizer, get_wquantizer


class BaseBlockwiseQuantization(BlockwiseOpt):
    def __init__(self, model, quant_config, input, padding_mask, config):
        super().__init__(model, quant_config, input, padding_mask, config)
        self.set_quant_config()

    def w_qdq(self, module, wquantizer):
        args = {'lowbound_factor': None, 'upbound_factor': None}
        if hasattr(module, 'buf_lowbound_factor'):
            args['lowbound_factor'] = module.buf_lowbound_factor
        if hasattr(module, 'buf_upbound_factor'):
            args['upbound_factor'] = module.buf_upbound_factor

        if module.weight.data.dtype == torch.float8_e4m3fn:
            tmp_weight \
                = weight_cast_to_bf16(module.weight,
                                      module.weight_scale_inv).to(torch.bfloat16)
        else:
            tmp_weight = module.weight

        tmp_weight = wquantizer.fake_quant_weight_dynamic(tmp_weight, args)

        if module.weight.data.dtype == torch.float8_e4m3fn:
            tmp_weight, module.weight_scale_inv.data = weight_cast_to_fp8(tmp_weight)

        return tmp_weight

    def w_q(self, module, wquantizer):
        return wquantizer.real_quant_weight_dynamic(module.weight.data)

    def a_qdq(self, act, module, aquantizer, input_index=0):
        if self.act_static:
            args = {
                'scales': (getattr(module, f'buf_act_scales_{input_index}', None)),
                'zeros': (getattr(module, f'buf_act_zeros_{input_index}', None)),
                'qmax': (getattr(module, f'buf_act_qmax_{input_index}', None)),
                'qmin': (getattr(module, f'buf_act_qmin_{input_index}', None)),
            }
            return aquantizer.fake_quant_act_static(act, args)
        else:
            return aquantizer.fake_quant_act_dynamic(act)

    def get_replacement_params(self, mode='fake_quant', w_only=False, name=None):
        params_dict = {}
        if mode in ['fake_quant', 'fake_quant_wo_kv']:
            if not self.mix_bits:
                params_dict['a_qdq'] = (
                    partial(self.a_qdq, aquantizer=self.aquantizer)
                    if not w_only
                    else None
                )
                params_dict['w_qdq'] = partial(self.w_qdq, wquantizer=self.wquantizer)
            else:
                params_dict['mix_bits'] = True
                params_dict['a_qdq'] = self.a_qdq
                params_dict['w_qdq'] = self.w_qdq
                params_dict['mix_bits_map'] = self.mix_bits_map
                params_dict['quantizer_mix_bits'] = self.quantizer_mix_bits
                params_dict['wquantizer_default'] = self.wquantizer
                params_dict['aquantizer_default'] = self.aquantizer
                params_dict['w_only_default'] = w_only

        elif mode in _REALQUANT_LINEAR_MAP_.keys():
            params_dict['w_q'] = partial(self.w_q, wquantizer=self.wquantizer)
            params_dict['quant_config'] = self.quant_config

        elif mode == 'online_rotate':
            had_K, K = get_hadK(
                self.intermediate_size if 'down_proj' in name else self.num_heads
            )
            params_dict = {
                'had_K': had_K,
                'K': K,
                'online_full_had': 'down_proj' in name,
                'online_partial_had': 'o_proj' in name,
                'had_dim': (
                    None if 'down_proj' in name else self.hidden_size // self.num_heads
                ),
                'fp32_had': self.fp32_had,
            }

        elif mode == 'quant_attn':
            params_dict = {
                'matmul_a1_qdq': partial(
                    self.a_qdq, aquantizer=self.aquantizer, input_index=0
                ),
                'matmul_a2_qdq': partial(
                    self.a_qdq, aquantizer=self.aquantizer, input_index=1
                ),
                'softmax_a_qdq': (
                    partial(self.a_qdq, aquantizer=self.aquantizer)
                    if self.quant_softmax
                    else None
                ),
            }

        elif mode == 'quant_act_fn':
            params_dict = {'a_qdq': partial(self.a_qdq, aquantizer=self.aquantizer)}

        return params_dict

    def alloc_bits(self, mix_bits_settings):
        for i in range(len(mix_bits_settings)):
            mix_bits_setting = mix_bits_settings[f'setting_{i}']
            if mix_bits_setting['do_quant']:
                wquantizer_mix_bits = self.quant_module(**mix_bits_setting['weight'])
                if 'act' in mix_bits_setting:
                    w_only_mix_bits = False
                    aquantizer_mix_bits = self.quant_module(**mix_bits_setting['act'])
                else:
                    w_only_mix_bits = True
                self.quantizer_mix_bits.append(
                    {
                        'layer_name': mix_bits_setting['layer_name'],
                        'do_quant': mix_bits_setting['do_quant'],
                        'w_only_mix_bits': w_only_mix_bits,
                        'wquantizer': wquantizer_mix_bits,
                        'aquantizer': (
                            aquantizer_mix_bits if not w_only_mix_bits else None
                        ),
                    }
                )
            else:
                self.quantizer_mix_bits.append(
                    {
                        'layer_name': mix_bits_setting['layer_name'],
                        'do_quant': mix_bits_setting['do_quant'],
                    }
                )

        for i in range(len(self.quantizer_mix_bits)):
            logger.info(f'quantizer_mix_bits {i} : {self.quantizer_mix_bits[i]}')
            layer_name = self.quantizer_mix_bits[i]['layer_name']
            for name in layer_name:
                n_layeridx = name.split('#')
                assert (
                    len(n_layeridx) == 1 or len(n_layeridx) == 2
                ), 'layer_name in mix_bits must be name#1-3-4 or name.'
                if len(n_layeridx) == 2:
                    n = n_layeridx[0]
                    layeridx = n_layeridx[1].split('-')
                    layeridx = [int(idx) for idx in layeridx]
                else:
                    n = n_layeridx[0]
                    layeridx = 'all'
                if layeridx == 'all':
                    for k in range(self.num_blocks):
                        self.mix_bits_map[k][n] = i
                else:
                    for k in layeridx:
                        self.mix_bits_map[k][n] = i

    def set_quant_config(self):
        self.mix_bits = 'mix_bits' in self.quant_config
        self.mix_bits_map = [{} for _ in range(self.num_blocks)]
        self.quantizer_mix_bits = []

        if 'ignored_layers' in self.config:
            self.mixed_precision = True
            self.ignored_block_ids = self.config.ignored_layers.get('block_ids', [])
            self.ignored_layer_names = self.config.ignored_layers.get('layer_names', [])
            self.ignored_speical_names = self.config.ignored_layers.get('speical_names', [])
        else:
            self.mixed_precision = False

        self.quant_out = self.quant_config.get('quant_out', False)
        self.tp = self.quant_config.get('tp', 1)
        self.quant_config['weight']['tp'] = self.tp

        # select quantizer
        # weight
        quant_type = self.quant_config['weight'].get('quant_type', 'int-quant')
        if quant_type == 'int-quant':
            if self.quant_config['weight']['bit'] == 48:
                self.weight_quant_module = Weight48IntegerQuantizer
            else:
                self.weight_quant_module = IntegerQuantizer
        elif quant_type == 'float-quant':
            self.weight_quant_module = FloatQuantizer
        logger.info(f'The used Weight Quant Module is {self.weight_quant_module}')
        self.wquantizer = self.weight_quant_module(**self.quant_config['weight'])

        # act
        if 'act' in self.quant_config:
            if self.quant_config['weight']['granularity'] == 'per_block':
                assert self.quant_config['act']['granularity'] == 'per_group'
                assert self.quant_config['act']['group_size'] \
                    == self.quant_config['weight']['block_size']
            self.w_only = False
            quant_type = self.quant_config['act'].get('quant_type', 'int-quant')
            if quant_type == 'int-quant':
                if self.quant_config['act']['bit'] == 48:
                    self.act_quant_module = Weight48IntegerQuantizer
                else:
                    self.act_quant_module = IntegerQuantizer
            elif quant_type == 'float-quant':
                self.act_quant_module = FloatQuantizer
            self.quant_config['act']['tp'] = self.tp
            self.aquantizer = self.act_quant_module(**self.quant_config['act'])
            self.act_static = self.quant_config['act'].get('static', False)
            if self.act_static:
                assert (
                    self.quant_config['act']['granularity'] == 'per_tensor'
                ), 'Only support per_tensor static quant'
            self.quant_attn = self.quant_config['act'].get('quant_attn', False)
            if self.quant_attn:
                assert self.config['model']['type'] in ['Vit', 'DeepseekV2']
                self.quant_softmax = self.quant_config['act'].get(
                    'quant_softmax', False
                )
            self.quant_act_fn = self.quant_config['act'].get('quant_act_fn', False)
        else:
            self.w_only = True
            self.aquantizer = None
            self.act_static = False
            self.quant_attn = False
            self.quant_softmax = False
            self.quant_act_fn = False

        # set mix-bits quant config
        if self.mix_bits:
            mix_bits_settings = self.quant_config['mix_bits']
            logger.info(f'mix_bits_settings number: {len(mix_bits_settings)}')
            logger.info(
                f'mix_bits_settings:\n'
                f'{json.dumps(mix_bits_settings, ensure_ascii=False, indent=4)}'
            )
            self.alloc_bits(mix_bits_settings)

            logger.info(
                f'self.mix_bits_map:\n'
                f'{json.dumps(self.mix_bits_map, ensure_ascii=False, indent=4)}'
            )

        # set kv cache quant config
        if 'kvcache' in self.quant_config:
            self.quant_config['kvcache']['static'] = self.act_static
            kv_special_cfg = self.quant_config['kvcache'].get('special', {})
            act_static_cfg = {}
            if self.act_static:
                act_static_cfg.update(self.config.calib.n_sample)
                act_static_cfg.update(self.config.calib.bs)
            kv_quant_type = self.quant_config['kvcache'].get('quant_type', 'int-quant')
            self.kv_module = KV_REGISTRY[self.quant_config['kvcache']['method']](
                kv_quant_type, self.quant_config['kvcache'],
                self.model.model_config.text_config.num_hidden_layers, **kv_special_cfg, **act_static_cfg
            )
            self.quant_kvcache = True
            self.model.kvcache_buffer.append(self.kv_module)
        else:
            self.quant_kvcache = False

        # set special quant config
        special_config = self.quant_config.get('special', {})
        self.true_sequential = special_config.get('true_sequential', False)

        # set weight clip config
        self.weight_clip = special_config.get('weight_clip', False)
        if self.weight_clip or special_config.get('search_clip_init', False):
            self.save_clip = special_config.get('save_clip', False)
            if self.save_clip:
                self.clip_path = special_config['clip_path']
            self.clip_version = special_config.get('clip_version', 'v1')
            if self.clip_version == 'v2':
                assert self.wquantizer.calib_algo == 'learnable'
            clip_sym = special_config.get('clip_sym', self.wquantizer.sym)
            self.auto_clipper = AutoClipper(
                w_only=self.w_only,
                mix_bits_map=self.mix_bits_map,
                quantizer_mix_bits=self.quantizer_mix_bits,
                wquantizer=self.wquantizer,
                aquantizer=self.aquantizer,
                clip_version=self.clip_version,
                clip_sym=clip_sym,
                save_clip=self.save_clip,
                padding_mask=self.padding_mask,
            )

        # set transformation config
        self.save_scale = special_config.get('save_scale', False)
        if self.save_scale:
            self.scale_path = special_config['scale_path']
            self.act_scales = {}

        # set online-rotation config
        self.online_rotate = special_config.get('online_rotate', False)
        if self.online_rotate:
            assert (
                self.config['model']['type'] in ['Opt', 'Llama']
            ), 'Please set online_rotate=False'
            self.fp32_had = special_config.get('fp32_had', False)
        self.hidden_size = self.model.model_config.text_config.hidden_size
        self.set_model_config()
        self.modality = self.quant_config.modality
        logger.info(f'self.quant_objects : {self.quant_config.modality}')

        # set token reduction config
        if 'token_reduction' in self.quant_config:
            token_reduction_cfg = self.quant_config['token_reduction']
            TOKEN_REDUCTION_REGISTRY[self.quant_config['token_reduction']['method']](
                token_reduction_cfg, self.model, self.blocks
            )

        self.do_gqa_trans = special_config.get('do_gqa_trans', False)
        logger.info(f'self.do_gqa_trans : {self.do_gqa_trans}')

    def set_model_config(self):
        self.hidden_size = self.model.model_config.text_config.hidden_size
        self.num_heads = self.model.model_config.text_config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        if hasattr(self.model.model_config.text_config, 'intermediate_size'):
            self.intermediate_size = self.model.model_config.text_config.intermediate_size
        if hasattr(self.model.model_config.text_config, 'num_key_value_heads'):
            self.num_key_value_heads = self.model.model_config.text_config.num_key_value_heads
            self.num_key_value_groups = self.num_heads // self.num_key_value_heads
            if self.num_key_value_groups > 1:
                self.has_gqa = True
            else:
                self.has_gqa = False
        else:
            self.has_gqa = False

    def replace_rotate_linears(self, block):
        for n, m in block.named_modules():
            if isinstance(m, nn.Linear) and (
                'down_proj' in n or 'o_proj' in n or 'fc2' in n or 'out_proj' in n
            ):
                subset = {'layers': {n: m}}
                self.model.replace_module_subset(
                    RotateLinear,
                    block,
                    subset,
                    None,
                    self.get_replacement_params(
                        mode='online_rotate', w_only=self.w_only, name=n
                    ),
                )

    def replace_act_fn(self, block, extra_modules):
        act_fn_dict = self.model.get_act_fn_in_block(block)
        layers_dict = {'layers': act_fn_dict}
        self.model.replace_module_subset(
            LlmcActFn,
            block,
            layers_dict,
            self.block_idx,
            self.get_replacement_params(
                mode='quant_act_fn', w_only=self.w_only, name=None
            ),
        )
        extra_modules.update(act_fn_dict)

    def replace_attention(self, block, extra_modules):
        attn_layers_dict = self.model.get_attn_in_block(block)
        layers_dict = {'layers': attn_layers_dict}
        attn_module = _LLMC_ATTN_MAP_[self.config['model']['type']]
        self.model.replace_module_subset(
            attn_module,
            block,
            layers_dict,
            self.block_idx,
            self.get_replacement_params(
                mode='quant_attn', w_only=self.w_only, name=None
            ),
        )

        matmul_modules = self.model.get_matmul_in_block(block)
        softmax_modules = (
            self.model.get_softmax_in_block(block) if self.quant_softmax else {}
        )
        extra_modules.update(matmul_modules)
        extra_modules.update(softmax_modules)

    @torch.no_grad()
    def collect_block_qparams(self, block):
        named_linears = self.model.get_block_linears(block)
        for n, m in named_linears.items():
            args = {}
            if hasattr(m, 'buf_lowbound_factor'):
                args['lowbound_factor'] = m.buf_lowbound_factor
            if hasattr(m, 'buf_upbound_factor'):
                args['upbound_factor'] = m.buf_upbound_factor

            if m.weight.data.dtype == torch.float8_e4m3fn:
                tmp_weight_data = weight_cast_to_bf16(m.weight.data,
                                                      m.weight_scale_inv.data).to(torch.bfloat16)
            else:
                tmp_weight_data = m.weight.data

            (
                tensor,
                scales,
                zeros,
                max_int,
                min_int,
            ) = self.wquantizer.get_tensor_qparams(tmp_weight_data, args=args)

            m.register_buffer('buf_scales', scales.detach())
            m.register_buffer('buf_zeros', zeros.detach())
            m.register_buffer('buf_qmax', torch.tensor(max_int).to(self.dev))
            m.register_buffer('buf_qmin', torch.tensor(min_int).to(self.dev))

    def block_forward(self, block, input_data=None):
        output = []

        if input_data is None:
            input_data = self.input['data']

        for i in range(len(input_data)):
            input_data[i] = input_data[i].to(device=next(block.parameters()).device)
            for k in self.input['kwargs'][i]:
                if torch.is_tensor(self.input['kwargs'][i][k]):
                    self.input['kwargs'][i][k] = self.input['kwargs'][i][k].to(
                        device=next(block.parameters()).device
                    )  # noqa
                if isinstance(self.input['kwargs'][i][k], tuple):
                    self.input['kwargs'][i][k] = tuple(
                        tmp.to(device=next(block.parameters()).device)
                        for tmp in self.input['kwargs'][i][k]
                    )  # noqa
            with torch.no_grad():
                out = block(input_data[i], **self.input['kwargs'][i])
                if isinstance(out, tuple):
                    out = out[0]
                output.append(out)
        return output

    def block_opt(self, block):

        if self.quant_kvcache:
            self.register_kv_cache(block)

        block = block.cuda()
        named_linears = self.model.get_block_linears(block)
        extra_modules = self.model.get_extra_modules(block)

        if self.quant_attn:
            self.replace_attention(block, extra_modules)
        if self.quant_act_fn:
            self.replace_act_fn(block, extra_modules)

        input_feat_modules = {
            k: v for d in [named_linears, extra_modules] for k, v in d.items()
        }
        logger.info(f'input_feat_modules: {input_feat_modules}')
        input_feat = defaultdict(list)

        handles = self.register_hooks(input_feat_modules, input_feat)

        self.block_init(block)

        self.run(block, input_feat, handles)

        block = block.cpu()
        del input_feat, block
        gc.collect()
        torch.cuda.empty_cache()

    def register_hooks(self, input_feat_modules, input_feat):
        handles = []
        if not self.data_free:
            for name in input_feat_modules:
                handles.append(
                    input_feat_modules[name].register_forward_hook(
                        functools.partial(
                            self.cache_input_hook, name=name, feat_dict=input_feat
                        )
                    )
                )
        return handles

    def run(self, block, input_feat, handles):
        if not self.data_free:
            if self.quant_out:
                self.block_forward(block)
            else:
                self.input['data'] = self.block_forward(block)

            for h in handles:
                h.remove()
            torch.cuda.empty_cache()

            self.block_transform(block, input_feat, self.input['kwargs'])
        else:
            self.block_transform(block)

        if not self.data_free and self.quant_out:
            self.model.replace_module_block(
                FakeQuantLinear,
                block,
                self.block_idx,
                self.get_replacement_params(
                    mode='fake_quant', w_only=self.w_only, name=None
                ),
            )
            self.set_non_linear_mode('fake_quant', block, False)
            self.input['data'] = self.block_forward(block)
        torch.cuda.empty_cache()

    def block_transform(self, block, input_feat, block_kwargs):
        logger.info(f'Start transform the {self.block_idx}-th block')
        subsets = self.model.get_subsets_in_block(block)

        if self.act_static:
            self.register_non_linear_qparams(block, input_feat)

        self.set_non_linear_mode('fake_quant', block, False)

        for index, subset in enumerate(subsets):
            logger.info(f'subset: {subset}')
            layers_dict = subset['layers']
            input_name = subset['input'][0]
            inspect_has_kwargs = subset['has_kwargs']
            if inspect_has_kwargs:
                if 'sub_keys' in subset:
                    subset_kwargs = [
                        {k: block_kwargs[0][v] for k, v in subset['sub_keys'].items()}
                    ]
                else:
                    subset_kwargs = block_kwargs
            else:
                subset_kwargs = {}
            self.subset_transform(
                subset,
                input_feat,
                subset_kwargs,
            )
            if self.act_static:
                input_tensors = copy.deepcopy(input_feat[input_name])
                self.register_act_qparams(layers_dict, input_tensors)
                del input_tensors

            if self.true_sequential and index != len(subsets) - 1:
                next_subset = subsets[index + 1]
                input_feat_subset = self.rehook_next_subset(block, subset, next_subset)
                input_feat.update(input_feat_subset)

        self.set_non_linear_mode('fake_quant', block, True)
        logger.info(f'End transform the {self.block_idx}-th block')

    def rehook_next_subset(self, block, subset, next_subset):
        self.subset_init(next_subset)
        self.model.replace_module_subset(
            FakeQuantLinear,
            block,
            subset,
            self.block_idx,
            self.get_replacement_params(
                mode='fake_quant', w_only=self.w_only, name=None
            ),
        )

        input_feat_subset = defaultdict(list)
        input_feat_modules = next_subset['layers']
        handles = self.register_hooks(input_feat_modules, input_feat_subset)

        self.block_forward(block)
        for h in handles:
            h.remove()

        return input_feat_subset

    def collect_layers_weights(self, layers, tensor_parallelize_style=None):
        weights = []
        for _m in layers:
            if _m.weight.data.dtype == torch.float8_e4m3fn:
                fp8_scale = _m.weight_scale_inv
                tmp_weight = weight_cast_to_bf16(_m.weight, fp8_scale).to(torch.bfloat16)
                weights.append(tmp_weight)
            else:
                weights.append(_m.weight)
        return weights

    @torch.no_grad()
    def register_kv_cache(self, block):
        attn_layers_dict = self.model.get_attn_in_block(block)
        attn_layer = attn_layers_dict[list(attn_layers_dict.keys())[0]]
        setattr(attn_layer, 'kvcache', self.kv_module)
        attn_layer.register_forward_pre_hook(
            self.kv_cache_input_hook(attn_layer), with_kwargs=True
        )

    @torch.no_grad()
    def register_non_linear_qparams(self, block, input_feat):
        layer_types = [
            ('quant_attn', self.model.get_matmul_in_block),
            ('quant_softmax', self.model.get_softmax_in_block, 'quant_attn'),
            ('quant_act_fn', self.model.get_act_fn_in_block),
        ]

        for mode, layer_func, *dependency in layer_types:
            if getattr(self, mode, True) and all(
                getattr(self, dep, True) for dep in dependency
            ):
                layers_dict = layer_func(block)
                for name, layer in layers_dict.items():
                    input_tensors = copy.deepcopy(input_feat[name])
                    self.register_act_qparams({name: layer}, input_tensors)
                    del input_tensors

    @torch.no_grad()
    def register_act_qparams(self, layers_dict, act_tensors):
        scales_list, zeros_list, qmin_list, qmax_list = (
            self.aquantizer.get_batch_tensors_qparams(act_tensors)
        )
        world_size = int(os.environ['WORLD_SIZE'])

        for i, (scales, zeros, qmin, qmax) in enumerate(
            zip(scales_list, zeros_list, qmin_list, qmax_list)
        ):
            scales = scales.cuda()
            dist.all_reduce(scales, op=dist.ReduceOp.SUM)
            scales = scales / world_size

            for name, layer in layers_dict.items():
                if not isinstance(
                    layer, tuple(_LLMC_LINEAR_TYPES_ + _TRANSFORMERS_LINEAR_TYPES_)
                ):
                    continue
                layer.register_buffer(f'buf_act_scales_{i}', scales)
                layer.register_buffer(f'buf_act_zeros_{i}', zeros.cuda())
                layer.register_buffer(f'buf_act_qmin_{i}', qmin.cuda())
                layer.register_buffer(f'buf_act_qmax_{i}', qmax.cuda())

    @torch.no_grad()
    def repeat_gqa_scales(self, scales):
        scales = scales.view(1, self.num_key_value_heads, self.head_dim)
        scales = torch.repeat_interleave(scales, dim=1, repeats=self.num_key_value_groups)
        return scales

    @torch.no_grad()
    def apply_scale(self, scales, prev_op, layers):
        assert (
            len(prev_op) == 1
        ), 'Only support single prev_op. If multi prev_ops, code need to be updated.'
        if isinstance(
            prev_op[0], tuple(_LLMC_LINEAR_TYPES_ + _TRANSFORMERS_LINEAR_TYPES_)
        ):
            assert len(layers) == 1
            logger.info('apply scale between fc and fc')
            self.scale_fc_fc(prev_op[0], layers[0], scales)
        elif isinstance(prev_op[0], tuple(_LLMC_LN_TYPES_ + _TRANSFORMERS_LN_TYPES_)):
            logger.info('apply scale between ln and fc')
            self.scale_ln_fcs(prev_op[0], layers, scales)
        else:
            raise NotImplementedError(f'prev_op {type(prev_op[0])} not supported yet!')

    @torch.no_grad()
    def apply_shift(self, shifts, prev_op, layers):
        if shifts is None:
            return

        assert (
            len(prev_op) == 1
        ), 'Only support single prev_op. If multi prev_ops, code need to be updated.'
        if isinstance(
            prev_op[0], tuple(_LLMC_LINEAR_TYPES_ + _TRANSFORMERS_LINEAR_TYPES_)
        ):
            assert len(layers) == 1
            self.shift_fc_fc(prev_op[0], layers[0], shifts)
        elif isinstance(prev_op[0], tuple(_LLMC_LN_TYPES_ + _TRANSFORMERS_LN_TYPES_)):
            self.shift_ln_fcs(prev_op[0], layers, shifts)
        else:
            raise NotImplementedError(f'prev_op {type(prev_op[0])} not supported yet!')

    @torch.no_grad()
    def scale_fc_fc(self, fc1, fc2, scales):
        scales = scales.to(fc1.weight.device)
        if fc1.out_features == fc2.in_features * 3:
            logger.info('fc1.out_features == fc2.in_features * 3')
            num_heads = self.model.get_num_attention_heads()
            fc1.weight.t_()
            org_shape = fc1.weight.shape
            fc1.weight.data = fc1.weight.data.reshape(org_shape[0] * num_heads, 3, -1)
            value = fc1.weight.data[:, 2, :].reshape(org_shape[0], -1)
            fc1.weight.data[:, 2, :] = value.div(scales.view(-1)).reshape(
                fc1.weight[:, 2, :].shape
            )
            fc1.weight.data = fc1.weight.data.reshape(org_shape).t_()
            if hasattr(fc1, 'bias') and fc1.bias is not None:
                fc1.bias.data = fc1.bias.data.reshape(num_heads, 3, -1)

                value = fc1.bias.data[:, 2, :].reshape(-1)

                fc1.bias.data[:, 2, :] = value.div(scales.view(-1)).reshape(
                    fc1.bias[:, 2, :].shape
                )
                fc1.bias.data = fc1.bias.data.reshape(-1)
        elif fc1.out_features == fc2.in_features * 2:
            logger.info('fc1.out_features == fc2.in_features * 2')
            fc1.weight.data[fc1.weight.data.shape[0] // 2:].div_(scales.view(-1, 1))
            if hasattr(fc1, 'bias') and fc1.bias is not None:
                fc1.bias.data[fc1.bias.data.shape[0] // 2:].div_(scales.view(-1))
        elif fc1.out_features == fc2.in_features:
            logger.info('fc1.out_features == fc2.in_features')
            assert fc1.out_features == fc2.in_features

            if hasattr(fc1, 'bias') and fc1.bias is not None:
                fc1.bias.div_(scales.view(-1))

            if fc1.weight.data.dtype == torch.float8_e4m3fn:
                fp8_scale = fc1.weight_scale_inv
                tmp_weight_data = weight_cast_to_bf16(fc1.weight.data, fp8_scale).to(torch.bfloat16)
                tmp_weight_data.div_(scales.view(-1, 1))

                fc1.weight.data, fc1.weight_scale_inv.data = weight_cast_to_fp8(tmp_weight_data)
            else:
                fc1.weight.div_(scales.view(-1, 1))

        elif self.has_gqa and self.do_gqa_trans:
            if hasattr(fc1, 'bias') and fc1.bias is not None:
                fc1.bias.div_(scales.view(-1))
            fc1.weight.div_(scales.view(-1, 1))

            if fc1.out_features != fc2.in_features:
                logger.info('GQA scale this fc-fc.')
                scales = self.repeat_gqa_scales(scales)
        else:
            logger.error(f'fc1.out_features: {fc1.out_features}')
            logger.error(f'fc2.in_features: {fc2.in_features}')
            raise Exception('Can not scale this fc-fc.')

        if fc2.weight.data.dtype == torch.float8_e4m3fn:
            fp8_scale = fc2.weight_scale_inv
            tmp_weight_data = weight_cast_to_bf16(fc2.weight.data, fp8_scale).to(torch.bfloat16)
            tmp_weight_data.mul_(scales.view(1, -1))
            fc2.weight.data, fc2.weight_scale_inv.data = weight_cast_to_fp8(tmp_weight_data)
        else:
            fc2.weight.mul_(scales.view(1, -1))

    @torch.no_grad()
    def shift_fc_fc(self, fc1, fc2, shifts):
        if fc1.out_features == fc2.in_features * 3:
            num_heads = self.model.get_model_config().to_dict().get('n_head', None)
            if hasattr(fc1, 'bias') and fc1.bias is not None:
                fc1.bias.data = fc1.bias.data.reshape(num_heads, 3, -1)

                value = fc1.bias.data[:, 2, :].reshape(-1)
                fc1.bias.data[:, 2, :] = (value - shifts).reshape(
                    fc1.bias[:, 2, :].shape
                )
                fc1.bias.data = fc1.bias.data.reshape(-1)
        else:
            assert fc1.out_features == fc2.in_features

            if hasattr(fc1, 'bias') and fc1.bias is not None:
                fc1.bias.sub_(shifts)

        if hasattr(fc2, 'bias') and fc2.bias is not None:
            fc2.bias.add_(fc2.weight @ shifts)
        else:
            if hasattr(self, 'use_shift') and self.use_shift:
                del fc2.bias
                fc2.register_buffer('bias', fc2.weight @ shifts)

    @torch.no_grad()
    def shift_ln_fcs(self, ln, fcs, shifts):
        if not isinstance(fcs, list):
            fcs = [fcs]

        if self.model.has_bias():
            ln.bias.sub_(shifts)

        for fc in fcs:
            if self.model.has_bias():
                fc.bias.add_(fc.weight @ shifts)
            else:
                if hasattr(self, 'use_shift') and self.use_shift:
                    del fc.bias
                    fc.register_buffer('bias', fc.weight @ shifts)

        for p in ln.parameters():
            assert torch.isnan(p).sum() == 0
        for fc in fcs:
            for p in fc.parameters():
                assert torch.isnan(p).sum() == 0

    @torch.no_grad()
    def scale_ln_fcs(self, ln, fcs, scales):
        if not isinstance(fcs, list):
            fcs = [fcs]
        scales = scales.to(ln.weight.device)
        ln.weight.div_(scales)

        if hasattr(ln, 'bias') and ln.bias is not None:
            ln.bias.div_(scales)

        for fc in fcs:
            if fc.weight.data.dtype == torch.float8_e4m3fn:
                fp8_scale = fc.weight_scale_inv.data
                tmp_weight_data = weight_cast_to_bf16(fc.weight.data, fp8_scale).to(torch.bfloat16)
                tmp_weight_data.mul_(scales.view(1, -1))
                fc.weight.data, fc.weight_scale_inv.data = weight_cast_to_fp8(tmp_weight_data)
            else:
                fc.weight.mul_(scales.view(1, -1))

        for p in ln.parameters():
            assert torch.isnan(p).sum() == 0
        for fc in fcs:
            for p in fc.parameters():
                assert torch.isnan(p).sum() == 0

    def rotate_pre_layers(self, pre_layers, Q):
        for layer in pre_layers:
            if layer.weight.data.dtype == torch.float8_e4m3fn:
                layer.weight.data \
                    = weight_cast_to_bf16(layer.weight.data,
                                          layer.weight_scale_inv.data).to(torch.bfloat16)
            dtype = layer.weight.dtype
            layer.weight.data = torch.matmul(layer.weight.data.double(), Q).to(dtype)

            if hasattr(layer, 'weight_scale_inv'):
                layer.weight.data, layer.weight_scale_inv.data \
                    = weight_cast_to_fp8(layer.weight.data)
            torch.cuda.empty_cache()

    def rotate_post_layers(self, post_layers, Q, exact_had=False):
        for layer in post_layers:
            if layer.weight.data.dtype == torch.float8_e4m3fn:
                layer.weight.data \
                    = weight_cast_to_bf16(layer.weight.data,
                                          layer.weight_scale_inv.data).to(torch.bfloat16)
            dtype = layer.weight.dtype
            layer.weight.data = torch.matmul(Q.T, layer.weight.data.double()).to(dtype)

            if exact_had and self.online_rotate:
                apply_exact_had_to_linear(layer, had_dim=-1, output=False)

            if hasattr(layer, 'bias') and layer.bias is not None:
                b = layer.bias.data.to(torch.float64)
                layer.bias.data = torch.matmul(Q.T, b).to(dtype)

            if hasattr(layer, 'weight_scale_inv'):
                layer.weight.data, layer.weight_scale_inv.data \
                    = weight_cast_to_fp8(layer.weight.data)
            torch.cuda.empty_cache()

    def rotate_embeddings(self, Q):
        embeddings = self.model.get_embed_layers()
        assert len(embeddings) == 1
        for layer in embeddings:
            dtype = layer.weight.data.dtype
            W = layer.weight.data.to(device=self.dev, dtype=torch.float64)
            layer.weight.data = torch.matmul(W, Q).to(device='cpu', dtype=dtype)

    def rotate_head(self, Q):
        heads = self.model.get_head_layers()
        for layer in heads:
            dtype = layer.weight.data.dtype
            W = layer.weight.data.to(device=self.dev, dtype=torch.float64)
            layer.weight.data = torch.matmul(W, Q).to(device='cpu', dtype=dtype)

    def fuse_ln_fcs(self, ln, fcs):
        for fc in fcs:
            if fc.weight.data.dtype == torch.float8_e4m3fn:
                fc.weight.data \
                    = weight_cast_to_bf16(fc.weight.data,
                                          fc.weight_scale_inv.data).to(torch.bfloat16)
            fc_dtype = fc.weight.dtype
            if hasattr(ln, 'bias') and ln.bias is not None:
                W = fc.weight.data.double().clone()
            fc.weight.data = (fc.weight.data.double() * ln.weight.double()).to(fc_dtype)
            if hasattr(ln, 'bias') and ln.bias is not None:
                if fc.bias is None:
                    fc.bias = torch.nn.Parameter(
                        torch.zeros(fc.out_features, dtype=torch.float64)
                    )
                fc.bias.data = fc.bias.data.double().to(device=W.device) + torch.matmul(
                    W, ln.bias.double()
                )
                fc.bias.data = fc.bias.data.to(fc_dtype)

            if hasattr(fc, 'weight_scale_inv'):
                fc.weight.data, fc.weight_scale_inv.data = weight_cast_to_fp8(fc.weight.data)
            torch.cuda.empty_cache()

    def remove_mean_from_embed(self):
        embeddings = self.model.get_embed_layers()
        for layer in embeddings:
            W = layer.weight.data.double()
            layer.weight.data = (W - W.mean(dim=-1, keepdim=True)).to(
                layer.weight.data.dtype
            )

    def bake_mean_into_fc(self, fc):
        fc_dtype = fc.weight.dtype
        W_ = fc.weight.data.double()
        fc.weight.data = W_ - W_.mean(dim=-2, keepdim=True)
        fc.weight.data = fc.weight.data.to(fc_dtype)
        if hasattr(fc, 'bias') and fc.bias is not None:
            b_ = fc.bias.data.double()
            fc.bias.data = b_ - b_.mean()
            fc.bias.data = fc.bias.data.to(fc_dtype)

    @torch.no_grad()
    def scaling_input(self, x, scales, is_gqa):
        if is_gqa:
            scales_tmp = self.repeat_gqa_scales(scales)
        else:
            scales_tmp = scales
        if hasattr(self, '_bs') and self._bs < x.shape[0]:
            x_tmp = torch.empty_like(x)
            for i, batch in enumerate(x):
                batch_scale = scales_tmp.view(1, -1)
                x_tmp[i] = batch / batch_scale
        else:
            x_tmp = x / scales_tmp.view(1, -1)
        return x_tmp

    @torch.no_grad()
    def update_input_feat(self, scale, input_feat, layers_dict, is_gqa):
        for layer_name in layers_dict:
            for i in range(len(input_feat[layer_name])):
                inp = input_feat[layer_name][i]
                scale = scale.to(inp.device)
                input_feat[layer_name][i] = self.scaling_input(inp, scale, is_gqa)

    @torch.no_grad()
    def set_non_linear_mode(self, quant_format, module, mode):
        assert mode in [True, False]
        if quant_format != 'fake_quant':
            return
        for name, m in module.named_modules():
            if 'kvcache' in name:
                continue
            if getattr(m, 'calib', None) is not None:
                m.calib = mode

    def set_no_quant_layer(self):
        if self.ignored_speical_names:
            assert hasattr(self.model, 'block_name_prefix'), \
                'block_name_prefix missing in model'
        ignored_block_ids = []
        for item in self.ignored_block_ids:
            match = re.match(r'(\d+)-(\d+)', str(item))
            if match:
                start, end = int(match.group(1)), int(match.group(2))
                ignored_block_ids.extend(range(start, end + 1))
            else:
                ignored_block_ids.append(int(item))

        for idx, block in enumerate(self.blocks):
            for n, m in block.named_modules():
                if idx in ignored_block_ids and n in self.ignored_layer_names:
                    m.register_buffer('no_quant', torch.tensor(True))
                else:
                    layer_name = f'{self.model.block_name_prefix}.{idx}.{n}'
                    if layer_name in self.ignored_speical_names:
                        m.register_buffer('no_quant', torch.tensor(True))

    @torch.no_grad()
    def deploy(self, quant_format, keep_device=False):
        logger.info(f'-- deploy_{quant_format}_model start --')
        logger.info(f'quant_config : {self.quant_config}')

        module_mapping = {
            'origin_float': OriginFloatLinear,
            'fake_quant': EffcientFakeQuantLinear,
            'fake_quant_wo_kv': EffcientFakeQuantLinear,
        }
        module_mapping.update(_REALQUANT_LINEAR_MAP_)

        if quant_format not in module_mapping:
            raise NotImplementedError(
                f"Quant format '{quant_format}' is not implemented."
            )
        if self.mixed_precision and 'quant' in quant_format:
            self.set_no_quant_layer()

        module = module_mapping[quant_format]
        if self.modality == 'vision':
            self.model.replace_vision_module_all(
                module,
                self.get_replacement_params(mode=quant_format, w_only=self.w_only),
                keep_device=keep_device,
            )
        if self.modality == 'language':
            self.model.replace_language_module_all(
                module,
                self.get_replacement_params(mode=quant_format, w_only=self.w_only),
                keep_device=keep_device,
            )
        self.set_non_linear_mode(quant_format, self.model.model, False)

        if self.quant_kvcache:
            if quant_format == 'origin_float':
                self.kv_module.use_org_kv = True
            elif quant_format == 'fake_quant_wo_kv':
                self.kv_module.use_org_kv = True
            elif quant_format == 'fake_quant':
                self.kv_module.use_org_kv = False
                if self.act_static:
                    self.kv_module.calib = False

        if self.model.mm_model is not None:
            logger.info(f'Now, the mm_model is: {self.model.mm_model}')

        logger.info(f'-- deploy_{quant_format}_model done --')

    @torch.no_grad()
    def copy_tokenizer(self, path):
        self.model.tokenizer.save_pretrained(path)
        logger.info('copy tokenizer done --')

    @torch.no_grad()
    def contiguous_params(self):
        if self.model.mm_model is not None:
            for name, param in self.model.mm_model.named_parameters():
                if not param.is_contiguous():
                    param.data = param.data.contiguous()

            for name, param in self.model.mm_model.named_buffers():
                if not param.is_contiguous():
                    param.data = param.data.contiguous()
        else:
            for name, param in self.model.model.named_parameters():
                if not param.is_contiguous():
                    param.data = param.data.contiguous()

            for name, param in self.model.model.named_buffers():
                if not param.is_contiguous():
                    param.data = param.data.contiguous()

    @torch.no_grad()
    def save_model(self, path):
        if int(os.environ['RANK']) != 0:
            return
        self.contiguous_params()
        if self.config.model.type in ['Llava', 'InternVL2', 'Mllama', 'Qwen2vl']:
            self.model.vlm_model.language_model = self.model.get_model()
            self.model.vlm_model.save_pretrained(path)
            logger.info('save model done --')
            self.copy_tokenizer(path)
        elif self.config.model.type in ['Qwen2Audio']:
            self.model.alm_model.language_model = self.model.get_model()
            self.model.alm_model.save_pretrained(path)
            logger.info('save model done --')
            self.copy_tokenizer(path)
        elif self.config.model.type in ['InternOmni']:
            self.model.avlm_model.language_model = self.model.get_model()
            self.model.avlm_model.save_pretrained(path)
            logger.info('save model done --')
            self.copy_tokenizer(path)
        else:
            self.model.get_model().save_pretrained(path)
            logger.info('save model done --')
            self.copy_tokenizer(path)


网站公告

今日签到

点亮在社区的每一天
去签到