一、eva_vit.py
1.VisionTransformer
# Excerpt: only forward_features of the EVA VisionTransformer is shown here;
# the remaining members (including __init__) live elsewhere in the file.
class VisionTransformer(nn.Module):
    def forward_features(self, x):
        """Turn an image batch into token features after all transformer blocks.

        Steps: patch embedding -> prepend the CLS token -> add the absolute
        position embedding (when ``self.pos_embed`` is not None) -> position
        dropout -> every block in ``self.blocks``, each receiving the same
        relative-position bias (or None).

        Args:
            x: input passed to ``self.patch_embed``; presumably
                (batch, channels, height, width) — TODO confirm.

        Returns:
            Tensor of shape (batch, 1 + num_tokens, dim), where num_tokens and
            dim are the sizes produced by ``self.patch_embed``.
        """
        tokens = self.patch_embed(x)
        # Unpacking also enforces that patch_embed produced a 3-D tensor.
        batch, _, _ = tokens.size()

        # One learnable CLS token, broadcast over the batch
        # (cls_tokens impl originally borrowed from Phil Wang).
        cls = self.cls_token.expand(batch, -1, -1)
        tokens = torch.cat((cls, tokens), dim=1)

        if self.pos_embed is not None:
            tokens = tokens + self.pos_embed
        tokens = self.pos_drop(tokens)

        # The relative-position bias is computed once and shared by all blocks.
        bias = None if self.rel_pos_bias is None else self.rel_pos_bias()

        for block in self.blocks:
            if self.use_checkpoint:
                # Gradient checkpointing: recompute this block in backward
                # instead of storing its activations (saves memory).
                tokens = checkpoint.checkpoint(block, tokens, bias)
            else:
                tokens = block(tokens, bias)
        return tokens
input(batch_size, channel_num, height, width) -> patch_embed → (batch_size, num_patches, embed_dim)(其中 num_patches = (H/patch_size) × (W/patch_size);注意:这里通道维 channel_num 被投影成了 embed_dim) -> 在序列最前面拼接 cls_token(一个可学习参数,作用类似 BERT 的 [CLS],其输出用于汇聚整图语义)→ (batch_size, num_patches+1, embed_dim) -> 加上位置嵌入(x = x + self.pos_embed,仅当 pos_embed 不为 None 时;形状不变) -> dropout(pos_drop,形状不变) -> 依次通过所有 blocks(每个 block 都接收同一个 rel_pos_bias;形状不变)
解释:
(1)梯度检查点(Gradient Checkpointing),就是torch.utils.checkpoint.checkpoint
这个函数:
目的:节省显存。
做法:在前向传播时不保存部分中间激活(即“舍弃缓存”),只保存输入;在反向传播时重新执行前向计算以计算梯度。
换句话说,前向时“放弃缓存”,反向时“重新计算”。
这样能显著减少显存消耗,但代价是反向传播时增加了计算开销。
(2)get_intermediate_layers 函数(与 forward_features 几乎一样,区别在于会记录每个 block 的输出)
这是 Vision Transformer 中用来 获取中间层输出特征 的方法,也就是:
把每个 transformer block 之后的输出都记录下来,返回一个 list。
# Excerpt from get_intermediate_layers: presumably preceded by the same
# embedding steps as forward_features (not shown); only the block loop is here.
features = []
# The relative-position bias is computed once and shared by every block.
rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
for blk in self.blocks:
    # No gradient checkpointing in this loop, unlike forward_features.
    x = blk(x, rel_pos_bias)
    # Record the output of every block; `features` ends up with one entry per block.
    features.append(x)
2.create_eva_vit_g
def create_eva_vit_g(img_size=224, drop_path_rate=0.4, use_checkpoint=False, precision="fp16"):
    """Build the EVA ViT-g/14 image encoder and load its pretrained weights.

    Args:
        img_size: input resolution passed to VisionTransformer.
        drop_path_rate: stochastic-depth rate passed to VisionTransformer.
        use_checkpoint: enable gradient checkpointing inside the blocks.
        precision: when "fp16", the weights are converted with
            ``convert_weights_to_fp16`` after loading; any other value keeps
            the dtype the checkpoint was loaded with.

    Returns:
        The VisionTransformer, with the checkpoint loaded non-strictly.
    """
    vit_kwargs = dict(
        img_size=img_size,
        patch_size=14,
        use_mean_pooling=False,
        embed_dim=1408,
        depth=39,
        num_heads=1408 // 88,  # 16 heads of size 88
        mlp_ratio=4.3637,  # presumably chosen so that int(1408 * 4.3637) == 6144 — TODO confirm
        qkv_bias=True,
        drop_path_rate=drop_path_rate,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        use_checkpoint=use_checkpoint,
    )
    model = VisionTransformer(**vit_kwargs)

    weights_url = "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/eva_vit_g.pth"
    local_path = download_cached_file(weights_url, check_hash=False, progress=True)
    # NOTE(review): torch.load unpickles the downloaded file and no hash is
    # checked (check_hash=False); only use this with a trusted URL/cache.
    state_dict = torch.load(local_path, map_location="cpu")

    # Presumably resizes the checkpoint's position embedding to this model's
    # patch grid (needed when img_size differs from pretraining) — TODO confirm.
    interpolate_pos_embed(model, state_dict)
    # strict=False: missing/unexpected keys are tolerated (and not reported).
    model.load_state_dict(state_dict, strict=False)

    if precision == "fp16":
        convert_weights_to_fp16(model)
    return model
(1)几个概念
名称 | 作用 | 具体内容 | 举例 |
---|---|---|---|
meta_path | 模型结构和配置路径 | 模型参数配置文件 | config.json , model.yaml |
checkpoint | 模型权重文件(训练到某步保存的) | 模型参数权重 | pytorch_model.bin |
cached_file | 已缓存的下载文件(权重或配置等) | 缓存的权重、配置或词表等文件 | ~/.cache/huggingface/transformers/xxx |
map_location | PyTorch加载时设备映射参数 | 指定加载权重到哪个设备 | 'cpu' , 'cuda:0' |
二、Qformer.py
Q-Former 是一个中介结构,用来“将视觉特征转化为语言模型可理解的格式”。
1.BertEmbeddings
把输入 token 映射为向量:词嵌入(内容)+ 位置嵌入;如果传入了 query(可学习的 query tokens),则把它们拼接在文本嵌入之前(因此后续各层中序列的前 query_length 个位置是 query,其后是文本)。
2.BertSelfAttention和BertAttention
BertAttention
由 BertSelfAttention + BertSelfOutput 组成,并提供 prune_heads(永久裁剪注意力头)
class BertAttention(nn.Module):
    """Attention sub-layer of a BERT / Q-Former layer.

    Combines BertSelfAttention (self- or cross-attention, chosen by
    ``is_cross_attention``) with BertSelfOutput (dense + dropout + residual +
    LayerNorm), and supports permanently pruning attention heads.
    """

    def __init__(self, config, is_cross_attention=False):
        super().__init__()
        self.self = BertSelfAttention(config, is_cross_attention)
        self.output = BertSelfOutput(config)
        # Heads removed so far; handed to find_pruneable_heads_and_indices on
        # every prune call.
        self.pruned_heads = set()

    def prune_heads(self, heads):
        """Permanently remove the given attention heads (no-op when empty).

        Shrinks the query/key/value projections and the input side (dim=1) of
        the output dense layer, then updates the head-count bookkeeping.
        """
        if len(heads) == 0:
            return
        attn = self.self
        heads, index = find_pruneable_heads_and_indices(
            heads,
            attn.num_attention_heads,
            attn.attention_head_size,
            self.pruned_heads,
        )

        # Q/K/V are pruned with prune_linear_layer's default dim
        # (presumably the output features)...
        attn.query = prune_linear_layer(attn.query, index)
        attn.key = prune_linear_layer(attn.key, index)
        attn.value = prune_linear_layer(attn.value, index)
        # ...and the output projection loses the matching input columns.
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        attn.num_attention_heads = attn.num_attention_heads - len(heads)
        attn.all_head_size = attn.num_attention_heads * attn.attention_head_size
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
    ):
        """Run attention, then the residual/LayerNorm output block.

        Returns:
            ``(attention_output, *extras)``, where extras are whatever
            BertSelfAttention returned after its context tensor (attention
            probabilities when requested, then the key/value cache).
        """
        attn_results = self.self(
            hidden_states,
            attention_mask,
            head_mask,
            encoder_hidden_states,
            encoder_attention_mask,
            past_key_value,
            output_attentions,
        )
        context = attn_results[0]
        extras = attn_results[1:]  # attention probs (if requested), past_key_value
        attention_output = self.output(context, hidden_states)
        return (attention_output,) + extras
BertSelfAttention
(1)实现了cross_attention+self_attention
# Build Q/K/V. Where keys/values come from depends on the mode.
# Decide between cross-attention and self-attention.
if is_cross_attention:
    # Cross-attention: keys/values are projected from the encoder output
    # (e.g. image features), and the encoder's mask replaces attention_mask.
    key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
    value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
    attention_mask = encoder_attention_mask
elif past_key_value is not None:
    # Cached self-attention: project the current hidden states and append them
    # to the cached keys/values along dim 2 (presumably the sequence axis of a
    # (batch, heads, seq, head_dim) layout — TODO confirm).
    key_layer = self.transpose_for_scores(self.key(hidden_states))
    value_layer = self.transpose_for_scores(self.value(hidden_states))
    key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
    value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
else:
    # Plain self-attention: keys/values from the current hidden states.
    key_layer = self.transpose_for_scores(self.key(hidden_states))
    value_layer = self.transpose_for_scores(self.value(hidden_states))
# Queries always come from the current hidden states.
mixed_query_layer = self.query(hidden_states)
# Split into per-head layout for the score computation.
query_layer = self.transpose_for_scores(mixed_query_layer)
# This step's (possibly concatenated) keys/values become the new cache.
past_key_value = (key_layer, value_layer)
(2)计算注意力分数(attention scores);若使用相对位置编码(relative_key / relative_key_query),还会加上由相对距离得到的位置分数
# Take the dot product between "query" and "key" to get the raw attention scores.
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

if (
    self.position_embedding_type == "relative_key"
    or self.position_embedding_type == "relative_key_query"
):
    # NOTE(review): seq_length comes from hidden_states (the query length); the
    # position scores below only line up with attention_scores when the key
    # length is the same (self-attention without a cache) — confirm.
    seq_length = hidden_states.size()[1]
    # Query positions as a column vector: (seq_length, 1).
    position_ids_l = torch.arange(
        seq_length, dtype=torch.long, device=hidden_states.device
    ).view(-1, 1)
    # Key positions as a row vector: (1, seq_length).
    position_ids_r = torch.arange(
        seq_length, dtype=torch.long, device=hidden_states.device
    ).view(1, -1)
    # Broadcasting yields the (seq_length, seq_length) matrix of relative
    # distances l - r, with values in [-(seq_length - 1), seq_length - 1].
    distance = position_ids_l - position_ids_r
    # Shift by max_position_embeddings - 1 so the lookup indices are non-negative.
    positional_embedding = self.distance_embedding(
        distance + self.max_position_embeddings - 1
    )
    positional_embedding = positional_embedding.to(
        dtype=query_layer.dtype
    )  # fp16 compatibility

    if self.position_embedding_type == "relative_key":
        # Einstein summation: for every (b, h, l, r) take the dot product of
        # query[b, h, l, :] with positional_embedding[l, r, :]. This is a batched
        # per-pair dot product, not a single matrix multiply.
        # query_layer is presumably [batch_size, num_heads, seq_len, head_dim].
        relative_position_scores = torch.einsum(
            "bhld,lrd->bhlr", query_layer, positional_embedding
        )
        attention_scores = attention_scores + relative_position_scores
    elif self.position_embedding_type == "relative_key_query":
        # Query-side term as above, plus a key-side term:
        # dot(key[b, h, r, :], positional_embedding[l, r, :]).
        relative_position_scores_query = torch.einsum(
            "bhld,lrd->bhlr", query_layer, positional_embedding
        )
        relative_position_scores_key = torch.einsum(
            "bhrd,lrd->bhlr", key_layer, positional_embedding
        )
        attention_scores = (
            attention_scores
            + relative_position_scores_query
            + relative_position_scores_key
        )

# Scale by sqrt(head size), as in scaled dot-product attention.
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
(3)加 mask、softmax、(可选)保存 cross-attention 的注意力图及其梯度、dropout、head_mask(用于选择性屏蔽某些注意力头),最后用注意力概率加权 value 并把多头合并回去
关于head_mask和prune_heads
功能 | head_mask | prune_heads() |
---|---|---|
作用对象 | 动态屏蔽某些注意力头(运行时) | 永久删除注意力头(结构上) |
使用时机 | 模型运行时动态控制 | 模型构建阶段永久裁剪 |
是否保留参数 | 保留参数,仅屏蔽输出 | 彻底删除参数 |
是否可逆 | 是(可以关闭遮罩) | 否(结构被修改) |
计算量是否减少 | 否,仍执行计算,只是丢弃结果 | 是,直接减少参数和计算量 |
if attention_mask is not None:
    # The mask is additive and precomputed once for all layers in
    # BertModel.forward(); masked positions presumably hold large negative
    # values so softmax drives them to ~0 — TODO confirm.
    attention_scores = attention_scores + attention_mask

# Normalize the attention scores to probabilities over the key axis.
attention_probs = nn.Softmax(dim=-1)(attention_scores)

if is_cross_attention and self.save_attention:
    # Optional cross-attention instrumentation: store the attention map and
    # register a hook that receives its gradient during backward.
    self.save_attention_map(attention_probs)
    attention_probs.register_hook(self.save_attn_gradients)

# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
attention_probs_dropped = self.dropout(attention_probs)

# Mask heads if we want to
if head_mask is not None:
    # head_mask selectively suppresses (multiplies) whole attention heads at
    # run time; the parameters are kept, unlike prune_heads.
    attention_probs_dropped = attention_probs_dropped * head_mask

# Weighted sum of values, then merge the heads back: swap dims 1 and 2 and
# flatten the last two dims into all_head_size.
context_layer = torch.matmul(attention_probs_dropped, value_layer)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(*new_context_layer_shape)
(4)输出上下文/上下文+注意力概率,添加past_key_value
# (context_layer,) or, when output_attentions is set, (context_layer, attention_probs).
# Note that the returned attention_probs is the map before dropout/head_mask.
outputs = (
    (context_layer, attention_probs) if output_attentions else (context_layer,)
)
# The key/value cache is always appended as the last element.
outputs = outputs + (past_key_value,)
3.BertSelfOutput 和 BertOutput(线性层 + dropout + 残差 + LayerNorm),BertIntermediate(线性层 + 非线性激活)
class BertOutput(nn.Module):
    """Second half of the feed-forward sub-layer.

    Projects intermediate_size -> hidden_size, applies dropout, adds the
    residual ``input_tensor`` and normalizes the sum with LayerNorm.
    """

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        projected = self.dropout(self.dense(hidden_states))
        return self.LayerNorm(projected + input_tensor)
class BertSelfOutput(nn.Module):
    """Output block of the attention sub-layer.

    Projects hidden_size -> hidden_size, applies dropout, adds the residual
    ``input_tensor`` and normalizes the sum with LayerNorm.
    """

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        residual_sum = self.dropout(self.dense(hidden_states)) + input_tensor
        return self.LayerNorm(residual_sum)
模块名称 | 所在位置 | Linear 输入 → 输出 | 功能说明 |
---|---|---|---|
BertSelfOutput | Attention 输出部分 | hidden_size → hidden_size | 处理 attention 输出:经 dropout 后加 residual,再做 LayerNorm |
BertOutput | FFN 输出部分 | intermediate_size → hidden_size | 把 FFN 的高维输出映射回原始维度:经 dropout 后加 residual,再做 LayerNorm |
class BertIntermediate(nn.Module):
    """First half of the feed-forward sub-layer: hidden_size -> intermediate_size, then activation."""

    def __init__(self, config):
        super().__init__()
        # Linear expansion.
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        # Non-linearity: a string names an entry in ACT2FN; anything else is
        # used directly as the activation callable.
        act = config.hidden_act
        self.intermediate_act_fn = ACT2FN[act] if isinstance(act, str) else act

    def forward(self, hidden_states):
        return self.intermediate_act_fn(self.dense(hidden_states))
4.BertLayer
(1)自注意力+缓存逻辑
# Self-attention with key/value caching.
# The self-attention cache is the first two entries (key, value) of
# past_key_value; the upstream comment calls these "positions 1,2".
self_attn_past_key_value = (
    past_key_value[:2] if past_key_value is not None else None
)
self_attention_outputs = self.attention(
    hidden_states,
    attention_mask,
    head_mask,
    output_attentions=output_attentions,
    past_key_value=self_attn_past_key_value,
)
# First element: the attention sub-layer output.
attention_output = self_attention_outputs[0]
# Middle elements: attention probabilities, present only when output_attentions is set.
outputs = self_attention_outputs[1:-1]
# Last element: this layer's updated key/value cache.
present_key_value = self_attention_outputs[-1]
(2)对query_length分段
if query_length > 0:
    # The first query_length positions are the query tokens (the remainder, if
    # any, is handled below as text); only the query part goes to cross-attention.
    query_attention_output = attention_output[:, :query_length, :]
(3)根据has_cross_attention判断是否要加cross_attention
if self.has_cross_attention:
    # Cross-attention layer: the query tokens attend to the encoder output
    # (e.g. image features) to extract visual information.
    assert (
        encoder_hidden_states is not None
    ), "encoder_hidden_states must be given for cross-attention layers"
    # attention_mask is passed positionally, but in cross-attention mode
    # BertSelfAttention replaces it with encoder_attention_mask.
    cross_attention_outputs = self.crossattention(
        query_attention_output,
        attention_mask,
        head_mask,
        encoder_hidden_states,
        encoder_attention_mask,
        output_attentions=output_attentions,
    )
    query_attention_output = cross_attention_outputs[0]
    # [1:-1] keeps only the cross-attention probabilities; the cross-attention
    # key/value cache (last element) is discarded.
    outputs = (
        outputs + cross_attention_outputs[1:-1]
    )  # add cross attentions if we output attention weights
(4)通过前馈网络(FFN,可按 chunk 分块计算)后输出:query 部分用 feed_forward_chunk_query,其后的文本部分用普通的 feed_forward_chunk,最后沿序列维拼接
# apply_chunking_to_forward splits the input along seq_len_dim into chunks of
# chunk_size_feed_forward, runs the feed-forward function on each chunk and
# concatenates the results.
# Query tokens use their own feed-forward path (feed_forward_chunk_query).
layer_output = apply_chunking_to_forward(
    self.feed_forward_chunk_query,
    self.chunk_size_feed_forward,  # chunk size
    self.seq_len_dim,  # dimension that is split into chunks
    query_attention_output,
)
if attention_output.shape[1] > query_length:
    # Text tokens after the queries skip cross-attention and use the regular
    # feed-forward path; the two parts are concatenated back along dim 1.
    layer_output_text = apply_chunking_to_forward(
        self.feed_forward_chunk,
        self.chunk_size_feed_forward,
        self.seq_len_dim,
        attention_output[:, query_length:, :],
    )
    layer_output = torch.cat([layer_output, layer_output_text], dim=1)
(5)query_length 为 0(没有 query,也就不做 cross-attention)时,整段序列直接走普通 FFN 后输出
else:
    # No query tokens (query_length == 0): the whole sequence goes through the
    # regular feed-forward path.
    layer_output = apply_chunking_to_forward(
        self.feed_forward_chunk,
        self.chunk_size_feed_forward,
        self.seq_len_dim,
        attention_output,
    )
# Output layout: (layer_output, *attention maps if requested, present_key_value).
outputs = (layer_output,) + outputs
outputs = outputs + (present_key_value,)
return outputs
5.BertEncoder
(1)准备各类输出的累加容器(未请求的输出设为 None),然后逐层遍历 BertLayer,取出该层对应的 head_mask 和 past_key_value
# Accumulators; each is None when the corresponding output was not requested.
all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None
all_cross_attentions = (
    () if output_attentions and self.config.add_cross_attention else None
)
next_decoder_cache = () if use_cache else None
for i in range(self.config.num_hidden_layers):
    layer_module = self.layer[i]
    if output_hidden_states:
        # Record each layer's input; the final output is appended after the loop.
        all_hidden_states = all_hidden_states + (hidden_states,)
    # Per-layer head mask and key/value cache (None when not provided).
    layer_head_mask = head_mask[i] if head_mask is not None else None
    past_key_value = past_key_values[i] if past_key_values is not None else None
(2)训练时若开启 gradient_checkpointing,则用 checkpoint 包装该层的前向(并强制关闭 use_cache);否则直接调用该层。之后收集缓存和注意力输出
if getattr(self.config, "gradient_checkpointing", False) and self.training:
    # Gradient checkpointing: activations inside each layer are not stored in
    # forward; they are recomputed during backward, trading compute for memory.
    if use_cache:
        logger.warning(
            "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
        )
        use_cache = False

    def create_custom_forward(module, layer_past):
        # The reentrant torch.utils.checkpoint.checkpoint forwards only
        # positional *args to the wrapped function (no keyword arguments), so
        # the non-tensor extras are bound in this closure.
        # layer_past is an explicit parameter rather than a reference to the
        # loop variable past_key_value: the closure runs again during backward
        # recomputation, after the loop has finished, and a late-bound lookup
        # would then see the last layer's past_key_value.
        def custom_forward(*inputs):
            return module(*inputs, layer_past, output_attentions, query_length)

        return custom_forward

    layer_outputs = torch.utils.checkpoint.checkpoint(
        create_custom_forward(layer_module, past_key_value),
        hidden_states,
        attention_mask,
        layer_head_mask,
        encoder_hidden_states,
        encoder_attention_mask,
    )
else:
    layer_outputs = layer_module(
        hidden_states,
        attention_mask,
        layer_head_mask,
        encoder_hidden_states,
        encoder_attention_mask,
        past_key_value,
        output_attentions,
        query_length,
    )

hidden_states = layer_outputs[0]
if use_cache:
    # The layer's present key/value cache is always its last output.
    next_decoder_cache += (layer_outputs[-1],)
if output_attentions:
    all_self_attentions = all_self_attentions + (layer_outputs[1],)
    # all_cross_attentions is None when config.add_cross_attention is False;
    # appending to it unconditionally would raise TypeError.
    if all_cross_attentions is not None:
        # NOTE(review): for a layer that ran no cross-attention, BertLayer's
        # outputs are (layer_output, self_attn_probs, present_key_value), so
        # index 2 is the cache rather than an attention map — confirm intended.
        all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
(3)输出
if output_hidden_states:
    # Add the last layer's output, so all_hidden_states holds
    # num_hidden_layers + 1 entries (the encoder input plus each layer's output).
    all_hidden_states = all_hidden_states + (hidden_states,)
if not return_dict:
    # Tuple form: same order as the dataclass below, but entries that are None
    # are dropped, so positions depend on which outputs were requested.
    return tuple(
        v
        for v in [
            hidden_states,
            next_decoder_cache,
            all_hidden_states,
            all_self_attentions,
            all_cross_attentions,
        ]
        if v is not None
    )
return BaseModelOutputWithPastAndCrossAttentions(
    last_hidden_state=hidden_states,
    past_key_values=next_decoder_cache,
    hidden_states=all_hidden_states,
    attentions=all_self_attentions,
    cross_attentions=all_cross_attentions,
)
6.BertPooler
BERT 在预训练时以 [CLS] token 的向量作为整个句子的语义表示来做 NSP(下一句预测)任务,所以不像图像处理那样需要找出空间上的显著局部信息,因此不使用常见的 max/avg pool,而是直接取第一个 token 的向量,再经过线性层 + tanh 来增强表达能力。
class BertPooler(nn.Module):
    """Pool a sequence into one vector: the first token's hidden state through Linear + tanh."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        # "Pooling" only selects position 0 (the [CLS] token); there is no
        # max/avg over the sequence.
        cls_state = hidden_states[:, 0]
        return self.activation(self.dense(cls_state))
7.BertLMPredictionHead 和 BertOnlyMLMHead
# Depends on: BertPredictionHeadTransform
class BertLMPredictionHead(nn.Module):
    """Language-model head: transform the hidden states, then project them to vocabulary logits."""

    def __init__(self, config):
        super().__init__()
        self.transform = BertPredictionHeadTransform(config)
        # Projection to the vocabulary without its own bias. The weight is
        # presumably tied to the input word embeddings elsewhere — TODO confirm;
        # the per-token bias below is a separate output-only parameter.
        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.bias = nn.Parameter(torch.zeros(config.vocab_size))
        # Share the same Parameter as decoder.bias so that
        # `resize_token_embeddings` resizes both together.
        self.decoder.bias = self.bias

    def forward(self, hidden_states):
        return self.decoder(self.transform(hidden_states))
# Depends on: BertLMPredictionHead
class BertOnlyMLMHead(nn.Module):
    """Thin wrapper that exposes a BertLMPredictionHead as ``self.predictions``.

    The wrapper presumably exists so submodule/state-dict names follow the
    BERT layout (``...predictions.*``) — TODO confirm.
    """

    def __init__(self, config):
        super().__init__()
        self.predictions = BertLMPredictionHead(config)

    def forward(self, sequence_output):
        return self.predictions(sequence_output)
8.四个模型
方便起见,后面类的主要依赖关系就加在类开头的注释里
class BertPreTrainedModel(PreTrainedModel):
    """Abstract base for the Q-Former BERT models.

    Supplies the BERT config class and the weight-initialisation scheme, and
    inherits PreTrainedModel's interface for downloading and loading pretrained
    models.
    """

    config_class = BertConfig
    base_model_prefix = "bert"
    _keys_to_ignore_on_load_missing = [r"position_ids"]

    def _init_weights(self, module):
        """Initialise one submodule in place.

        - Linear / Embedding: weight ~ N(0, config.initializer_range); Linear
          bias (if any) set to zero.
        - LayerNorm: weight set to 1, bias set to 0.
        - Any other module: left untouched.
        """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            # Plain normal init; the TF reference uses truncated_normal instead.
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
# (Class bodies are elided with "..." in these notes.)
# Composition: BertEmbeddings + BertEncoder + BertPooler (subclass of BertPreTrainedModel).
class BertModel(BertPreTrainedModel)...
# Composition: BertModel + BertOnlyMLMHead (subclass of BertPreTrainedModel).
# Used for causal (autoregressive) language modeling: predicts the next token
# from left to right, suited to generation tasks (decoder mode).
class BertLMHeadModel(BertPreTrainedModel)...
# Composition: BertModel + BertOnlyMLMHead (subclass of BertPreTrainedModel).
# Used for masked language modeling (MLM), BERT's original pretraining task:
# input tokens are randomly masked and the model predicts the masked tokens.
class BertForMaskedLM(BertPreTrainedModel)...
三、modeling_llama.py
这个主要就是继承LlamaForCausalLMOrig来进行生成以及计算logits,代码很短,且易懂,就不再解释了。一些我看的时候的问题都在注释里。
# MiniGPT-4's language decoder: the HF LlamaForCausalLM (imported here as
# LlamaForCausalLMOrig) with forward() overridden to add a `reduction` argument
# for the language-modeling loss. It generates text from the text input plus the
# visual context (presumably supplied through inputs_embeds — TODO confirm).
class LlamaForCausalLM(LlamaForCausalLMOrig):
    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        # Presumably the Q-Former visual query embeddings concatenated with the
        # text embeddings — TODO confirm against the caller.
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,  # whether to return every layer's hidden states
        return_dict: Optional[bool] = None,
        # Reduction passed to CrossEntropyLoss; "none" yields one loss per sample (see below).
        reduction: Optional[str] = "mean",
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        r"""
        Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
            reduction (`str`, *optional*, defaults to `"mean"`):
                Passed to `CrossEntropyLoss`. With `"none"` the per-token losses are reshaped to
                `(batch_size, sequence_length - 1)` and averaged over dim 1, giving one loss per sample.

        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, LlamaForCausalLM

        >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS, local_files_only=True)
        >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER, local_files_only=True)

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        # self.model is presumably the base LLaMA decoder stack (without lm_head) — TODO confirm.
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # outputs[0] is the last layer's hidden states. Which entries follow
        # (past_key_values / hidden_states / attentions) depends on use_cache and
        # the output_* flags; with return_dict=False they are passed through
        # unchanged below via outputs[1:].
        hidden_states = outputs[0]
        # pretraining_tp > 1: compute the logits slice by slice over the vocabulary
        # (lm_head weight split along dim 0) and concatenate along the last dim.
        # All slices live on the same device here; this does not distribute work.
        if hasattr(self.config, 'pretraining_tp') and self.config.pretraining_tp > 1:
            lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
            logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
            logits = torch.cat(logits, dim=-1)  # vocabulary scores
        else:
            logits = self.lm_head(hidden_states)
        # Upcast to float32 for the loss and for the returned logits.
        logits = logits.float()

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss(reduction=reduction)
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism: labels may sit on a different device than the logits.
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)
            if reduction == "none":
                # Per-token losses -> (batch, seq_len - 1) -> one value per sample.
                # NOTE(review): positions labelled -100 contribute 0 but still count
                # in the .mean(1) denominator, so this is sum / (seq_len - 1), not a
                # mean over supervised tokens — confirm this is what callers expect.
                loss = loss.view(logits.size(0), -1).mean(1)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )