MiniGPT-4 Source Code Walkthrough: models


I. eva_vit.py

1. VisionTransformer

class VisionTransformer(nn.Module):
    def forward_features(self, x):
        x = self.patch_embed(x)
        batch_size, seq_len, _ = x.size()

        cls_tokens = self.cls_token.expand(batch_size, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
        x = torch.cat((cls_tokens, x), dim=1)
        if self.pos_embed is not None:
            x = x + self.pos_embed
        x = self.pos_drop(x)

        rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
        for blk in self.blocks:
            if self.use_checkpoint:  # whether gradient checkpointing is enabled
                x = checkpoint.checkpoint(blk, x, rel_pos_bias)
            else:
                x = blk(x, rel_pos_bias)
        return x

Shape flow: input (batch_size, channel_num, height, width) -> patch_embed -> (batch_size, num_patches, embed_dim), where num_patches = H/patch_size * W/patch_size and the channel dimension is replaced by embed_dim -> prepend cls_token (a learnable parameter, similar to a classification token, used to capture whole-image semantics) -> (batch_size, num_patches+1, embed_dim) -> add the positional embedding (x = x + self.pos_embed, shape unchanged) -> dropout (shape unchanged) -> transformer blocks blk(x) (shape unchanged).
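
A quick trace of the shapes, as a minimal sketch with dummy tensors (assuming a 224×224 input, patch size 14 and embed_dim 1408, as in the EVA ViT-g configuration below):

import torch

batch_size, height, width = 2, 224, 224
patch_size, embed_dim = 14, 1408
num_patches = (height // patch_size) * (width // patch_size)    # 16 * 16 = 256

x = torch.randn(batch_size, num_patches, embed_dim)             # output of patch_embed
cls_token = torch.zeros(1, 1, embed_dim)                        # learnable in the real model
cls_tokens = cls_token.expand(batch_size, -1, -1)
x = torch.cat((cls_tokens, x), dim=1)                           # (2, 257, 1408)
pos_embed = torch.zeros(1, num_patches + 1, embed_dim)          # learnable in the real model
x = x + pos_embed                                                # shape unchanged
print(x.shape)                                                   # torch.Size([2, 257, 1408])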

Explanation:

(1) Gradient checkpointing, i.e. the torch.utils.checkpoint.checkpoint function:

  • Goal: save GPU memory.

  • How: during the forward pass, some intermediate activations are not stored (the "cache" is dropped) and only the inputs are kept; during the backward pass the forward computation is re-run to recover what is needed for the gradients.

  • In other words: discard activations on the forward pass, recompute them on the backward pass.

  • This significantly reduces memory consumption, at the cost of extra computation during backpropagation.
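
A minimal, self-contained illustration of this trade-off (not MiniGPT-4 code; just torch.utils.checkpoint applied to a toy layer):

import torch
import torch.nn as nn
from torch.utils import checkpoint

layer = nn.Sequential(nn.Linear(1408, 1408), nn.GELU(), nn.Linear(1408, 1408))
x = torch.randn(2, 257, 1408, requires_grad=True)

y_plain = layer(x)                        # regular forward: intermediate activations are kept
y_ckpt = checkpoint.checkpoint(layer, x)  # checkpointed forward: only the input is saved;
                                          # the forward pass is re-run during backward
y_ckpt.sum().backward()                   # same gradients, at the cost of extra compute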

(2) The get_intermediate_layers function (almost identical to forward_features)

This is the Vision Transformer method used to collect intermediate-layer features, i.e.:

record the output after every transformer block and return the results as a list.

        features = []
        rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
        for blk in self.blocks:
            x = blk(x, rel_pos_bias)
            features.append(x)

2. create_eva_vit_g

def create_eva_vit_g(img_size=224,drop_path_rate=0.4,use_checkpoint=False,precision="fp16"):
    model = VisionTransformer(
        img_size=img_size,
        patch_size=14,
        use_mean_pooling=False,
        embed_dim=1408,
        depth=39,
        num_heads=1408//88,
        mlp_ratio=4.3637,
        qkv_bias=True,
        drop_path_rate=drop_path_rate,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        use_checkpoint=use_checkpoint,
    )  
    url = "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/eva_vit_g.pth"
    cached_file = download_cached_file(
        url, check_hash=False, progress=True
    )
    state_dict = torch.load(cached_file, map_location="cpu")
    interpolate_pos_embed(model,state_dict)
    
    incompatible_keys = model.load_state_dict(state_dict, strict=False)
#     print(incompatible_keys)
    
    if precision == "fp16":
#         model.to("cuda") 
        convert_weights_to_fp16(model)
    return model

(1) A few concepts

| Name | Role | Content | Example |
| --- | --- | --- | --- |
| meta_path | path to the model structure and configuration | model configuration files | config.json, model.yaml |
| checkpoint | model weight file (saved at some training step) | model parameter weights | pytorch_model.bin |
| cached_file | a downloaded file cached locally (weights, config, etc.) | cached weights, config, or vocab files | ~/.cache/huggingface/transformers/xxx |
| map_location | device-mapping argument when loading with PyTorch | which device to load the weights onto | 'cpu', 'cuda:0' |
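
A hedged sketch of how checkpoint loading, map_location and strict=False fit together (the file name and the tiny stand-in model are made up for illustration):

import torch
import torch.nn as nn

model = nn.Linear(4, 4)                                        # stand-in for the ViT built above
torch.save({"weight": torch.randn(4, 4)}, "demo_ckpt.pth")     # a toy "checkpoint"

state_dict = torch.load("demo_ckpt.pth", map_location="cpu")   # map_location: load onto the CPU
result = model.load_state_dict(state_dict, strict=False)       # strict=False tolerates mismatched keys
print(result.missing_keys)     # ['bias'] -- keys the model expects but the checkpoint lacks
print(result.unexpected_keys)  # []       -- keys the checkpoint has but the model does not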

II. Qformer.py

Q-Former is an intermediary module that converts visual features into a format the language model can understand.

1. BertEmbeddings

Embeds the input tokens into vectors, combining position embeddings, content (word) embeddings, and the learned query embeddings.
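
A rough sketch of the idea (the vocabulary size, hidden size and query tensor are illustrative stand-ins for what the Q-Former config and its learned query tokens provide):

import torch
import torch.nn as nn

vocab_size, hidden_size, max_pos = 30522, 768, 512
word_emb = nn.Embedding(vocab_size, hidden_size)
pos_emb = nn.Embedding(max_pos, hidden_size)

input_ids = torch.randint(0, vocab_size, (2, 10))            # text tokens
query_embeds = torch.randn(2, 32, hidden_size)               # learned query tokens (already vectors)

positions = torch.arange(input_ids.size(1))
text_embeds = word_emb(input_ids) + pos_emb(positions)       # content + position
embeddings = torch.cat([query_embeds, text_embeds], dim=1)   # queries are prepended to the text
print(embeddings.shape)                                       # torch.Size([2, 42, 768])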

2. BertSelfAttention and BertAttention

  • BertAttention

It is BertSelfAttention + BertSelfOutput + prune_heads (removal of attention heads).

class BertAttention(nn.Module):
    def __init__(self, config, is_cross_attention=False):
        super().__init__()
        self.self = BertSelfAttention(config, is_cross_attention)
        self.output = BertSelfOutput(config)
        self.pruned_heads = set()

    def prune_heads(self, heads):
        if len(heads) == 0:
            return
        heads, index = find_pruneable_heads_and_indices(
            heads,
            self.self.num_attention_heads,
            self.self.attention_head_size,
            self.pruned_heads,
        )

        # Prune linear layers
        self.self.query = prune_linear_layer(self.self.query, index)
        self.self.key = prune_linear_layer(self.self.key, index)
        self.self.value = prune_linear_layer(self.self.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Update hyper params and store pruned heads
        self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
        self.self.all_head_size = (
            self.self.attention_head_size * self.self.num_attention_heads
        )
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
    ):
        self_outputs = self.self(
            hidden_states,
            attention_mask,
            head_mask,
            encoder_hidden_states,
            encoder_attention_mask,
            past_key_value,
            output_attentions,
        )
        attention_output = self.output(self_outputs[0], hidden_states)

        outputs = (attention_output,) + self_outputs[
            1:
        ]  # add attentions if we output them
        return outputs

  • BertSelfAttention

(1) Implements both cross-attention and self-attention.

        # Define Q, K, V
        # Decide whether this is cross-attention or self-attention
        if is_cross_attention:
            key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
            value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
            attention_mask = encoder_attention_mask
            # K/V come from the encoder (e.g. image features)
        elif past_key_value is not None:
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))
            key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
            value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
            # K/V come from the current hidden states, concatenated with the cached past key/values
        else:
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))
            # K/V come from the current hidden states

        # Generate the query vectors from the input hidden_states
        mixed_query_layer = self.query(hidden_states)
        # Reshape into per-head form for the attention computation
        query_layer = self.transpose_for_scores(mixed_query_layer)

        past_key_value = (key_layer, value_layer)
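
For reference, transpose_for_scores just reshapes the flat Q/K/V projection into per-head form; a minimal sketch with illustrative sizes:

import torch

batch, seq_len, num_heads, head_dim = 2, 10, 12, 64
hidden = torch.randn(batch, seq_len, num_heads * head_dim)    # output of a Q/K/V linear layer

# (batch, seq_len, all_head_size) -> (batch, num_heads, seq_len, head_dim)
per_head = hidden.view(batch, seq_len, num_heads, head_dim).permute(0, 2, 1, 3)
print(per_head.shape)                                          # torch.Size([2, 12, 10, 64])
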
(2) Compute the attention scores; when relative position embeddings are used, distance-based relative position scores are added on top of the query-key dot product.
# Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

        if (
            self.position_embedding_type == "relative_key"
            or self.position_embedding_type == "relative_key_query"
        ):
            seq_length = hidden_states.size()[1]
            position_ids_l = torch.arange(
                seq_length, dtype=torch.long, device=hidden_states.device
            ).view(-1, 1) #(seq_length, 1)
            position_ids_r = torch.arange(
                seq_length, dtype=torch.long, device=hidden_states.device
            ).view(1, -1) #(1, seq_length)
            # position_ids_l and position_ids_r hold the row and column position indices of the sequence
            distance = position_ids_l - position_ids_r
            # broadcasting yields the relative distance between every pair of positions
            positional_embedding = self.distance_embedding(
                distance + self.max_position_embeddings - 1
            )  # the shift by self.max_position_embeddings - 1 keeps the embedding indices non-negative
            positional_embedding = positional_embedding.to(
                dtype=query_layer.dtype
            )  # fp16 compatibility

            if self.position_embedding_type == "relative_key":
                relative_position_scores = torch.einsum(
                    "bhld,lrd->bhlr", query_layer, positional_embedding
                )
                # torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) uses Einstein
                # summation: it contracts query_layer with positional_embedding over the head dim d.
                # Q: [batch_size, num_heads, seq_len, head_dim]
                attention_scores = attention_scores + relative_position_scores
            elif self.position_embedding_type == "relative_key_query":
                relative_position_scores_query = torch.einsum(
                    "bhld,lrd->bhlr", query_layer, positional_embedding
                )
                relative_position_scores_key = torch.einsum(
                    "bhrd,lrd->bhlr", key_layer, positional_embedding
                )
                attention_scores = (
                    attention_scores
                    + relative_position_scores_query
                    + relative_position_scores_key
                )

        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
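
A tiny sketch of the relative-position term with toy sizes (the nn.Embedding below plays the role of self.distance_embedding):

import torch
import torch.nn as nn

seq_len, num_heads, head_dim, max_pos = 4, 2, 8, 16
query = torch.randn(1, num_heads, seq_len, head_dim)

pos_l = torch.arange(seq_len).view(-1, 1)             # (seq_len, 1)
pos_r = torch.arange(seq_len).view(1, -1)             # (1, seq_len)
distance = pos_l - pos_r                              # (seq_len, seq_len), values in [-(L-1), L-1]

distance_embedding = nn.Embedding(2 * max_pos - 1, head_dim)
rel_emb = distance_embedding(distance + max_pos - 1)  # shift so the indices are non-negative

# "bhld,lrd->bhlr": for each query position l, dot its head vector with the embedding
# of its distance to every key position r
rel_scores = torch.einsum("bhld,lrd->bhlr", query, rel_emb)
print(rel_scores.shape)                                # torch.Size([1, 2, 4, 4])
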
(3) Apply the attention mask, softmax, optional saving of the attention map, and head_mask (used to selectively silence certain attention heads).

On head_mask vs. prune_heads (a small sketch contrasting the two follows after the code excerpt below):

| Aspect | head_mask | prune_heads() |
| --- | --- | --- |
| Target | dynamically masks certain attention heads (at runtime) | permanently deletes attention heads (structurally) |
| When used | controlled dynamically while the model runs | pruned once at model construction time |
| Parameters kept? | yes, parameters stay; only the outputs are masked | no, the parameters are removed entirely |
| Reversible? | yes (just stop applying the mask) | no (the structure is modified) |
| Less computation? | no, the heads are still computed and their results discarded | yes, parameters and compute are actually reduced |

        if attention_mask is not None:
            # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
            attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = nn.Softmax(dim=-1)(attention_scores)

        if is_cross_attention and self.save_attention:
            self.save_attention_map(attention_probs)
            attention_probs.register_hook(self.save_attn_gradients)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs_dropped = self.dropout(attention_probs)

        # Mask heads if we want to
        if head_mask is not None:
            # head_mask selectively silences certain attention heads
            attention_probs_dropped = attention_probs_dropped * head_mask

        context_layer = torch.matmul(attention_probs_dropped, value_layer)

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)
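
To make the head_mask vs. prune_heads contrast from the table above concrete, here is a small sketch with raw tensors (head_mask is shown numerically; the effect of prune_heads is summarised in the comments):

import torch

batch, num_heads, seq_len = 1, 4, 6
attention_probs = torch.softmax(torch.randn(batch, num_heads, seq_len, seq_len), dim=-1)

# head_mask: runtime and reversible; the masked head is still computed, its output is just zeroed
head_mask = torch.tensor([1.0, 0.0, 1.0, 1.0]).view(1, -1, 1, 1)   # silence head 1
masked_probs = attention_probs * head_mask

# prune_heads: structural and permanent; the Q/K/V/output weights of the pruned heads are
# physically removed (see prune_linear_layer above), so num_attention_heads drops from 4 to 3
# and the per-step compute actually shrinks.
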
(4) Output the context layer (together with the attention probabilities if requested), and append past_key_value.
        outputs = (
            (context_layer, attention_probs) if output_attentions else (context_layer,)
        )

        outputs = outputs + (past_key_value,)

3. BertSelfOutput and BertOutput (linear processing), BertIntermediate (non-linear processing)

class BertOutput(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states

class BertSelfOutput(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states

| Module | Location | Linear: input → output | Role |
| --- | --- | --- | --- |
| BertSelfOutput | after the attention output | hidden_size → hidden_size | processes the attention output, adds the residual and LayerNorm |
| BertOutput | after the FFN output | intermediate_size → hidden_size | maps the FFN's high-dimensional output back to hidden_size, adds the residual and LayerNorm |

class BertIntermediate(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        # linear up-projection
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]  # non-linear activation
        else:
            self.intermediate_act_fn = config.hidden_act

    def forward(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)
        return hidden_states
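
Putting the pieces together: the Transformer FFN sub-block is BertIntermediate followed by BertOutput, with the residual added around it. A minimal sketch with illustrative sizes (dropout omitted):

import torch
import torch.nn as nn

hidden_size, intermediate_size = 768, 3072
intermediate = nn.Sequential(nn.Linear(hidden_size, intermediate_size), nn.GELU())  # BertIntermediate
out_dense = nn.Linear(intermediate_size, hidden_size)                               # BertOutput.dense
norm = nn.LayerNorm(hidden_size)                                                    # BertOutput.LayerNorm

x = torch.randn(2, 10, hidden_size)
h = intermediate(x)                 # up-project + non-linearity
y = norm(out_dense(h) + x)          # down-project, residual connection, LayerNorm
print(y.shape)                      # torch.Size([2, 10, 768])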

4. BertLayer

(1) Self-attention plus caching logic

        # self-attention + caching logic
        # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
        self_attn_past_key_value = (
            past_key_value[:2] if past_key_value is not None else None
        )
        self_attention_outputs = self.attention(
            hidden_states,
            attention_mask,
            head_mask,
            output_attentions=output_attentions,
            past_key_value=self_attn_past_key_value,
        )
        attention_output = self_attention_outputs[0]  # the attention output itself
        outputs = self_attention_outputs[1:-1]  # the attention weights (if output_attentions)

        present_key_value = self_attention_outputs[-1]

(2) Slice out the query tokens using query_length

        if query_length > 0:
            query_attention_output = attention_output[:, :query_length, :]

(3) Use has_cross_attention to decide whether to apply cross-attention

            if self.has_cross_attention:  # let the queries attend to the encoder outputs (e.g., the image)
                assert (
                    encoder_hidden_states is not None
                ), "encoder_hidden_states must be given for cross-attention layers"
                # use cross-attention to extract image information
                cross_attention_outputs = self.crossattention(
                    query_attention_output,
                    attention_mask,
                    head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    output_attentions=output_attentions,
                )
                query_attention_output = cross_attention_outputs[0]
                outputs = (
                    outputs + cross_attention_outputs[1:-1]
                )  # add cross attentions if we output attention weights

(4) Output after the feed-forward (chunked MLP) processing

            layer_output = apply_chunking_to_forward(
                # split the long attention_output sequence into chunks, run each chunk through the
                # feed-forward (MLP) separately, then concatenate the results
                self.feed_forward_chunk_query,
                self.chunk_size_feed_forward,  # chunk size
                self.seq_len_dim,              # dimension to split along
                query_attention_output,
            )
            if attention_output.shape[1] > query_length:
                layer_output_text = apply_chunking_to_forward(
                    self.feed_forward_chunk,
                    self.chunk_size_feed_forward,
                    self.seq_len_dim,
                    attention_output[:, query_length:, :],
                )
                layer_output = torch.cat([layer_output, layer_output_text], dim=1)

(5) Without cross-attention, go straight to the feed-forward and output

        else:
            layer_output = apply_chunking_to_forward(
                self.feed_forward_chunk,
                self.chunk_size_feed_forward,
                self.seq_len_dim,
                attention_output,
            )
        outputs = (layer_output,) + outputs

        outputs = outputs + (present_key_value,)

        return outputs
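
The chunking used in step (4) can be sketched without the library helper: split the sequence dimension into chunks, run the feed-forward on each chunk, and concatenate the results, which is essentially what apply_chunking_to_forward does (the FFN below is a stand-in):

import torch
import torch.nn as nn

ffn = nn.Linear(768, 768)                        # stand-in for feed_forward_chunk
attention_output = torch.randn(2, 40, 768)
chunk_size, seq_len_dim = 8, 1

chunks = attention_output.split(chunk_size, dim=seq_len_dim)
layer_output = torch.cat([ffn(c) for c in chunks], dim=seq_len_dim)
print(layer_output.shape)                        # torch.Size([2, 40, 768]); same result, lower peak memory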

5. BertEncoder

(1) Iterate over the stacked BertLayers, preparing per-layer state (head mask, past key/values) and the output accumulators

        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None
        all_cross_attentions = (
            () if output_attentions and self.config.add_cross_attention else None
        )

        next_decoder_cache = () if use_cache else None

        for i in range(self.config.num_hidden_layers):
            layer_module = self.layer[i]
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_head_mask = head_mask[i] if head_mask is not None else None
            past_key_value = past_key_values[i] if past_key_values is not None else None

(2) If gradient checkpointing is required, checkpoint each layer (and handle the cache)

            if getattr(self.config, "gradient_checkpointing", False) and self.training:
            # In checkpoint mode, intermediate outputs are not stored; they are recomputed during the backward pass, which saves GPU memory.
                if use_cache:
                    logger.warn(
                        "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                    )
                    use_cache = False

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(
                            *inputs, past_key_value, output_attentions, query_length
                        )
                    # This is required because torch.utils.checkpoint.checkpoint() passes its inputs to the
                    # forward function as positional tensor arguments only (no keyword arguments), so the
                    # wrapper custom_forward captures the extra arguments (past_key_value, output_attentions, query_length) via a closure.
                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(layer_module),
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                )
            else:
                layer_outputs = layer_module(
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    past_key_value,
                    output_attentions,
                    query_length,
                )

            hidden_states = layer_outputs[0]
            if use_cache:
                next_decoder_cache += (layer_outputs[-1],)
            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)
                all_cross_attentions = all_cross_attentions + (layer_outputs[2],)

(3) Return the outputs

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(
                v
                for v in [
                    hidden_states,
                    next_decoder_cache,
                    all_hidden_states,
                    all_self_attentions,
                    all_cross_attentions,
                ]
                if v is not None
            )
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=next_decoder_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
            cross_attentions=all_cross_attentions,
        )

6. BertPooler

Following BERT, the [CLS] token's vector is taken as the semantic representation of the whole sequence (this is how BERT was trained for the NSP task). Unlike image processing, there is no need to pick out spatially salient local information, so instead of the usual max/avg pooling, the pooler simply takes the first token and applies a linear layer plus a non-linearity to strengthen its representation.

class BertPooler(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        # We "pool" the model by simply taking the hidden state corresponding
        # to the first token.
        first_token_tensor = hidden_states[:, 0]
        pooled_output = self.dense(first_token_tensor)
        pooled_output = self.activation(pooled_output)
        return pooled_output

7. BertLMPredictionHead and BertOnlyMLMHead

# BertLMPredictionHead = BertPredictionHeadTransform + a linear decoder over the vocabulary
class BertLMPredictionHead(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.transform = BertPredictionHeadTransform(config)

        # The output weights are the same as the input embeddings, but there is
        # an output-only bias for each token.
        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        self.bias = nn.Parameter(torch.zeros(config.vocab_size))

        # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
        self.decoder.bias = self.bias

    def forward(self, hidden_states):
        hidden_states = self.transform(hidden_states)
        hidden_states = self.decoder(hidden_states)
        return hidden_states

# BertOnlyMLMHead simply wraps BertLMPredictionHead
class BertOnlyMLMHead(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.predictions = BertLMPredictionHead(config)

    def forward(self, sequence_output):
        prediction_scores = self.predictions(sequence_output)
        return prediction_scores
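
The comment about the output weights matching the input embeddings refers to weight tying; a minimal sketch (the sizes are illustrative, not the actual Q-Former config):

import torch
import torch.nn as nn

vocab_size, hidden_size = 30522, 768
word_embeddings = nn.Embedding(vocab_size, hidden_size)
decoder = nn.Linear(hidden_size, vocab_size, bias=False)

# Tie the two matrices: the same parameter embeds tokens and scores them at the output
decoder.weight = word_embeddings.weight

hidden = torch.randn(2, 10, hidden_size)
logits = decoder(hidden)
print(logits.shape)                 # torch.Size([2, 10, 30522])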

8. The four model classes

For convenience, each class's main dependencies are noted in the comment at the top of the class.

class BertPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = BertConfig
    base_model_prefix = "bert"
    _keys_to_ignore_on_load_missing = [r"position_ids"]

    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()


# BertEmbeddings + BertEncoder + BertPooler (inherits from BertPreTrainedModel)
class BertModel(BertPreTrainedModel)...


# BertModel + BertOnlyMLMHead (inherits from BertPreTrainedModel)
# Used for autoregressive (causal) language modeling, i.e., predicting the next token left to right; suited to generation (decoder mode).
class BertLMHeadModel(BertPreTrainedModel)...



# BertModel + BertOnlyMLMHead (inherits from BertPreTrainedModel)
# Used for masked language modeling (MLM), BERT's original pre-training task: input tokens are randomly masked and the model predicts the masked words.
class BertForMaskedLM(BertPreTrainedModel)...

III. modeling_llama.py

This file mainly subclasses LlamaForCausalLMOrig to run generation and compute the logits and the loss. The code is short and easy to follow, so it is not broken down further; notes on the questions I had while reading it are in the comments.

# LlamaForCausalLM is effectively MiniGPT-4's language decoder: it generates text output conditioned on the text input and the visual context.
class LlamaForCausalLM(LlamaForCausalLMOrig):

    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,  # possibly the visual query embeddings produced by the Q-Former concatenated with the text embeddings
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,  # whether to return the intermediate hidden states
        return_dict: Optional[bool] = None,
        reduction: Optional[str] = "mean",
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        r"""
        Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, LlamaForCausalLM

        >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS, local_files_only=True)
        >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER, local_files_only=True)

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        # In MiniGPT-4, self.model is the LLaMA Transformer decoder that consumes the fused visual-query and text embeddings.
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]  # why do the first two entries of outputs look the same? The first is the hidden states and the second is past_key_values, but the incoming past_key_values is empty here
        # MiniGPT-4 supports model parallelism (pretraining_tp): the lm_head weight is sliced along the vocabulary and the logits are computed in parts, keeping the memory usage of large models manageable.
        if hasattr(self.config, 'pretraining_tp') and self.config.pretraining_tp > 1:
            lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
            logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
            logits = torch.cat(logits, dim=-1)  # prediction scores
        else:
            logits = self.lm_head(hidden_states)
        logits = logits.float()

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss(reduction=reduction)
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)
            if reduction == "none":
                loss = loss.view(logits.size(0), -1).mean(1)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
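
The shift-by-one labels implement the standard next-token objective; a tiny sketch (the token ids are made up, and -100 marks ignored positions, CrossEntropyLoss's default ignore_index):

import torch
from torch.nn import CrossEntropyLoss

vocab_size = 10
logits = torch.randn(1, 5, vocab_size)            # model outputs for 5 positions
labels = torch.tensor([[3, 7, -100, 2, 9]])       # -100 positions contribute no loss

shift_logits = logits[..., :-1, :].contiguous()   # positions 0..3 predict ...
shift_labels = labels[..., 1:].contiguous()       # ... tokens at positions 1..4

loss = CrossEntropyLoss()(shift_logits.view(-1, vocab_size), shift_labels.view(-1))
print(loss.item())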

