LogitsProcessor
是一个抽象基类,用于在生成序列的过程中对模型输出的logits进行处理。它的派生类实现了各种策略,以控制生成过程。
公共输入和输出
所有的LogitsProcessor
派生类都遵循相同的调用约定,即实现了__call__
方法,接受以下输入并返回处理后的logits:
输入
input_ids
(torch.LongTensor
): 形状为(batch_size, sequence_length)
的张量,表示当前已生成的序列的token索引。scores
(torch.FloatTensor
): 形状为(batch_size, vocab_size)
的张量,表示模型在当前时间步对所有词汇表中每个token的预测得分(通常是logits)。
输出
scores
(torch.FloatTensor
): 形状为(batch_size, vocab_size)
的张量,表示经过处理后的预测得分。这些得分随后将用于采样下一个token或选择概率最高的token。
类名/方法名 | 简要说明 |
---|---|
类 | 说明 |
LogitsProcessor |
所有在生成过程中可以应用的logits处理器的抽象基类。 |
LogitsProcessorList |
用于创建一个包含多个LogitsProcessor 的列表,以便对scores 张量进行依次处理。 |
MinLengthLogitsProcessor |
强制生成的序列至少达到最小长度,方法是在达到最小长度之前将EOS(结束)token的概率设置为负无穷大。 |
MinNewTokensLengthLogitsProcessor |
强制新生成的token序列至少达到最小长度,不包括提示(prompt)部分。 |
TemperatureLogitsWarper |
通过应用温度缩放来调节预测中token的概率分布,从而控制生成文本的随机性。 |
RepetitionPenaltyLogitsProcessor |
对已生成的token施加惩罚,防止模型重复生成相同的token序列。 |
EncoderRepetitionPenaltyLogitsProcessor |
对输入(encoder)的token施加惩罚或奖励,鼓励模型重复或避免重复输入中的内容。 |
TopPLogitsWarper |
实现核采样(top-p sampling),仅保留累积概率达到top_p 阈值的最可能的token。 |
TopKLogitsWarper |
实现top-k采样,仅保留概率最高的k 个token。 |
MinPLogitsWarper |
保留概率高于给定最小值min_p 的token,滤除低概率的token。 |
TypicalLogitsWarper |
实现典型采样,优先选择熵较低的token,避免选择过于常见或罕见的词汇。 |
EpsilonLogitsWarper |
实现epsilon采样,保留概率大于epsilon 的token,确保选出的token概率超过一定阈值。 |
EtaLogitsWarper |
基于熵动态调整eta 阈值,对低概率的token进行过滤,适应性地选择token。 |
NoRepeatNGramLogitsProcessor |
防止生成重复的n-gram序列,通过禁止已生成的n-gram再次出现。 |
EncoderNoRepeatNGramLogitsProcessor |
类似于NoRepeatNGramLogitsProcessor ,但针对的是输入(encoder)中已存在的n-gram,防止在生成中重复输入的n-gram。 |
SequenceBiasLogitsProcessor |
对特定的token序列施加加性偏置,可以增加或减少这些序列在生成中的概率。 |
NoBadWordsLogitsProcessor |
禁止生成包含特定“不良”词汇的序列,设置这些词汇的概率为负无穷大。 |
PrefixConstrainedLogitsProcessor |
在生成过程中,强制遵循指定的前缀token列表,限制生成的可能性,以满足特定的前缀约束。 |
HammingDiversityLogitsProcessor |
在beam search中引入哈夫曼多样性惩罚,鼓励生成的序列之间的多样性,避免不同beam生成相似的序列。 |
ForcedBOSTokenLogitsProcessor |
强制生成的序列以指定的开始(BOS)token开头,通常用于encoder-decoder模型。 |
ForcedEOSTokenLogitsProcessor |
在达到最大长度时,强制生成指定的结束(EOS)token,确保序列的终止。 |
InfNanRemoveLogitsProcessor |
移除logits中的无穷大(inf)和非数字(nan)值,避免生成过程中的数值错误。 |
ExponentialDecayLengthPenalty |
对EOS token的得分进行指数衰减,鼓励模型在适当的长度范围内结束生成,避免过长的输出。 |
LogitNormalization |
对logits进行log-softmax归一化,确保概率分布的正确性,常用于beam search后。 |
SuppressTokensAtBeginLogitsProcessor |
在生成开始时抑制特定的token,防止它们在序列的开头被生成。 |
SuppressTokensLogitsProcessor |
全局抑制指定的token,设置它们的logits为负无穷大,防止在任何位置生成。 |
WhisperTimeStampLogitsProcessor |
修改时间戳token的logits,用于控制生成时间戳的行为,常用于Whisper模型的转录任务。 |
WhisperNoSpeechDetection |
检测无语音段落,在检测到静音时,调整logits以反映无语音的状态。 |
ClassifierFreeGuidanceLogitsProcessor |
实现分类器自由引导(CFG),结合有条件和无条件的logits,以控制生成文本的倾向性。 |
AlternatingCodebooksLogitsProcessor |
在Bark模型的fine子模型中交替生成两个codebook,控制生成过程中的模式。 |
UnbatchedClassifierFreeGuidanceLogitsProcessor |
未批量化的CFG logits处理器,适用于非批量处理的模型。 |
BarkEosPrioritizerLogitsProcessor |
对于Bark模型,在满足特定条件时优先选择EOS token,确保序列适时结束。 |
WatermarkLogitsProcessor |
在生成的文本中嵌入水印,通过对特定的“绿色”token添加偏置,实现不可察觉的标记。 |
SynthIDTextWatermarkState |
SynthID文本水印的状态类,用于跟踪已生成的token序列和水印相关的信息。 |
SynthIDTextWatermarkLogitsProcessor |
实现SynthID文本水印的logits处理器,在生成过程中微调token的概率以嵌入水印,辅助检测生成内容的真实性。 |
全局方法 | 说明 |
---|---|
_get_ngrams(ngram_size, prev_input_ids, num_hypos) |
获取已生成的n-gram序列,用于防止重复生成。 |
_get_generated_ngrams(banned_ngrams, prev_input_ids, ngram_size, cur_len) |
确定当前生成位置应被禁止的token列表,基于之前生成的n-gram。 |
_calc_banned_ngram_tokens(ngram_size, prev_input_ids, num_hypos, cur_len) |
计算每个假设(hypothesis)中应被禁止的n-gram token,防止重复。 |
上述类和方法共同构成了生成文本时调整和控制模型输出的机制,帮助生成更符合预期的、流畅的文本。
派生类详细说明
1. MinLengthLogitsProcessor
功能:在生成序列未达到最小长度时,禁止生成结束符(EOS token),以确保生成的序列至少达到指定的最小长度。
输入参数
min_length
(int
): 最小生成长度。如果当前生成的序列长度小于min_length
,则禁止生成EOS token。eos_token_id
(Union[int, List[int], torch.Tensor]
): EOS token的ID或ID列表。
处理逻辑
- 检查
input_ids
的长度cur_len
是否小于min_length
。 - 如果是,则将
scores
中对应于eos_token_id
的位置设置为负无穷大-inf
,以禁止模型在此位置生成EOS token。
- 检查
2. MinNewTokensLengthLogitsProcessor
功能:与
MinLengthLogitsProcessor
类似,但只考虑新生成的tokens的长度,不包括初始的提示(prompt)部分。输入参数
prompt_length_to_skip
(int
): 提示部分的长度,在计算新生成的token长度时跳过此部分。min_new_tokens
(int
): 新生成的tokens的最小长度。eos_token_id
(Union[int, List[int], torch.Tensor]
): EOS token的ID或ID列表。
处理逻辑
- 计算新生成的tokens的长度:
new_tokens_length = len(input_ids) - prompt_length_to_skip
。 - 如果
new_tokens_length
小于min_new_tokens
,则将scores
中对应于eos_token_id
的位置设置为-inf
,以禁止生成EOS token。
- 计算新生成的tokens的长度:
3. TemperatureLogitsWarper
功能:通过温度缩放调节预测分布的平坦程度,影响生成过程的随机性。温度较高时(>1),分布更平坦,随机性更大;温度较低时(<1),分布更尖锐,更倾向于高概率的tokens。
输入参数
temperature
(float
): 温度值,需为正数。调整logits的平坦程度。
处理逻辑
- 将
scores
除以temperature
:scores = scores / temperature
。 - 温度高于1时,
scores
较小,softmax后分布更平坦;温度低于1时,scores
较大,softmax后分布更尖锐。
- 将
4. RepetitionPenaltyLogitsProcessor
功能:对已生成的tokens施加惩罚,防止重复生成相同的tokens。对于已经生成的tokens,如果它们再次出现,其对应的scores会被调整。
输入参数
penalty
(float
): 惩罚系数。大于1表示惩罚重复的tokens,介于0到1之间表示鼓励重复。
处理逻辑
- 对于每个batch,获取已生成的tokens
input_ids
在当前scores
中的得分。 - 根据得分的正负,对应地除以或乘以
penalty
:- 如果得分
<0
,则乘以penalty
(降低负得分的绝对值)。 - 如果得分
>0
,则除以penalty
(降低正得分)。
- 如果得分
- 更新
scores
中对应于已生成tokens的位置。
- 对于每个batch,获取已生成的tokens
5. EncoderRepetitionPenaltyLogitsProcessor
功能:与
RepetitionPenaltyLogitsProcessor
类似,但针对的是输入序列(encoder input ids),用于在生成序列时避免重复输入内容。输入参数
penalty
(float
): 惩罚系数。encoder_input_ids
(torch.LongTensor
): 输入序列的token IDs。
处理逻辑
- 对于
encoder_input_ids
中的每个token,获取其在scores
中的得分。 - 根据得分的正负,对应地除以或乘以
penalty
。 - 更新
scores
中对应于encoder_input_ids
的位置。
- 对于
6. TopPLogitsWarper
功能:实现核采样(top-p sampling),仅保留累积概率达到
top_p
阈值的tokens,滤除较低概率的tokens。输入参数
top_p
(float
): 累积概率阈值,0 < top_p <= 1
。保留概率最高的tokens,使其累积概率达到top_p
。filter_value
(float
): 被滤除的token的得分设置为该值,通常为-inf
。min_tokens_to_keep
(int
): 至少保留的token数量,不论top_p
的值。
处理逻辑
- 对
scores
进行排序,计算对应的softmax概率并累积(cumulative sum)。 - 找到使累积概率大于
(1 - top_p)
的位置,将这些位置对应的tokens标记为需要滤除。 - 确保至少保留
min_tokens_to_keep
个tokens。 - 将被滤除的tokens的
scores
设置为filter_value
。
- 对
7. TopKLogitsWarper
功能:仅保留概率最高的
top_k
个tokens,其余的tokens被滤除。输入参数
top_k
(int
): 要保留的tokens数量。filter_value
(float
): 被滤除的token的得分设置为该值。min_tokens_to_keep
(int
): 至少保留的token数量。
处理逻辑
- 找到
scores
中得分最高的top_k
个tokens。 - 将其余tokens的
scores
设置为filter_value
。 - 确保至少保留
min_tokens_to_keep
个tokens。
- 找到
8. MinPLogitsWarper
功能:保留概率高于
min_p
的tokens。min_p
根据最高概率的token进行缩放,动态调整阈值。输入参数
min_p
(float
): 最小概率阈值。filter_value
(float
): 被滤除的token的得分设置为该值。min_tokens_to_keep
(int
): 至少保留的token数量。
处理逻辑
- 计算
scores
对应的softmax概率probs
。 - 获取每个batch中最高概率
top_probs
,计算scaled_min_p = min_p * top_probs
。 - 滤除
probs
小于scaled_min_p
的tokens。 - 确保至少保留
min_tokens_to_keep
个tokens。
- 计算
9. TypicalLogitsWarper
功能:实现典型采样(typical sampling),优先选择与整个概率分布的熵接近的tokens,避免选择过于常见或罕见的tokens。
输入参数
mass
(float
): 累积概率阈值,0 < mass < 1
。filter_value
(float
): 被滤除的token的得分设置为该值。min_tokens_to_keep
(int
): 至少保留的token数量。
处理逻辑
- 计算
scores
对应的对数概率normalized
和概率p
。 - 计算熵
entropy = - (normalized * p).sum()
。 - 计算偏移值
shifted_scores = abs(-normalized - entropy)
。 - 对
shifted_scores
进行排序,计算累积概率cumulative_probs
。 - 滤除令累积概率超过
mass
的tokens。 - 确保至少保留
min_tokens_to_keep
个tokens。
- 计算
10. EpsilonLogitsWarper
功能:实现epsilon采样,保留概率大于
epsilon
的tokens。如果没有token满足条件,则保留概率最高的min_tokens_to_keep
个tokens。输入参数
epsilon
(float
): 最小概率阈值。filter_value
(float
): 被滤除的token的得分设置为该值。min_tokens_to_keep
(int
): 至少保留的token数量。
处理逻辑
- 计算
scores
对应的softmax概率probs
。 - 滤除
probs
小于epsilon
的tokens。 - 确保至少保留
min_tokens_to_keep
个tokens。
- 计算
11. EtaLogitsWarper
功能:根据熵动态调整
eta
值,对低概率的tokens进行过滤。eta
的计算与当前分布的熵相关。输入参数
epsilon
(float
): 用于计算动态阈值eta
的参数。filter_value
(float
): 被滤除的token的得分设置为该值。min_tokens_to_keep
(int
): 至少保留的token数量。device
(str
): 计算所使用的设备。
处理逻辑
- 计算
scores
对应的softmax概率probs
。 - 计算熵
entropy
。 - 计算动态阈值
eta = min(epsilon, sqrt(epsilon) * exp(-entropy))
。 - 滤除
probs
小于eta
的tokens。 - 确保至少保留
min_tokens_to_keep
个tokens。
- 计算
12. NoRepeatNGramLogitsProcessor
功能:防止生成重复的n-gram序列。对于已经生成的n-gram,不允许再次生成。
输入参数
ngram_size
(int
): n-gram的大小。例如,ngram_size=2
表示防止重复的二元组。
处理逻辑
- 调用
_calc_banned_ngram_tokens
函数,获取当前时间步每个batch中应被禁止的token列表。 - 将
scores
中对应于被禁止tokens的位置设置为-inf
,以防止生成这些token。
- 调用
13. EncoderNoRepeatNGramLogitsProcessor
功能:防止生成与输入序列中存在的n-gram重复的序列。
输入参数
encoder_ngram_size
(int
): n-gram的大小。encoder_input_ids
(torch.LongTensor
): 输入序列的token IDs。
处理逻辑
- 基于
encoder_input_ids
预计算n-gram序列。 - 对于每个生成的序列,检查是否与输入的n-gram重复。
- 将重复的n-gram对应的
scores
位置设置为-inf
。
- 基于
14. SequenceBiasLogitsProcessor
功能:对指定的token或token序列施加偏置,可以增加或降低其被生成的概率。
输入参数
sequence_bias
(List[List[Union[List[int], float]]]
): 包含序列和对应偏置值的列表。例如,[[[token1_id, token2_id], bias_value]]
。
处理逻辑
- 对于长度为1的序列,直接在
scores
中对应的位置加上偏置值。 - 对于长度大于1的序列,检查前缀是否匹配,如果匹配则对下一个可能完成序列的token加上偏置值。
- 对于长度为1的序列,直接在
15. NoBadWordsLogitsProcessor
功能:禁止生成包含特定“不良”词汇的序列,将这些词汇对应的
scores
设置为-inf
。输入参数
bad_words_ids
(List[List[int]]
): 要禁止的词汇列表,每个词汇由其token ID列表表示。eos_token_id
(Union[int, List[int], torch.Tensor]
): EOS token的ID。
处理逻辑
- 将
bad_words_ids
中的序列转换为SequenceBiasLogitsProcessor
的形式,偏置值设置为-inf
。 - 在生成过程中,确保这些“不良”词汇不被生成。
- 将
16. PrefixConstrainedLogitsProcessor
功能:根据指定的函数
prefix_allowed_tokens_fn
限制每个时间步可能生成的token集合,实现前缀约束。输入参数
prefix_allowed_tokens_fn
(Callable[[int, torch.Tensor], List[int]]
): 一个函数,接收batch_id
和input_ids
,返回允许的token ID列表。num_beams
(int
): beam search的beam大小。
处理逻辑
- 对于每个batch和beam,调用
prefix_allowed_tokens_fn
获取允许的token列表。 - 将不在允许列表中的token的
scores
设置为-inf
。
- 对于每个batch和beam,调用
17. HammingDiversityLogitsProcessor
功能:在group beam search中引入多样性惩罚,鼓励不同beam group生成多样化的序列。
输入参数
diversity_penalty
(float
): 多样性惩罚系数。num_beams
(int
): beam的总数量。num_beam_groups
(int
): beam group的数量。
处理逻辑
- 在生成过程中,每个beam group之间共享信息。
- 对于在当前时间步不同group生成的token,计算它们的频率
token_frequency
。 - 对于在其他group中已生成的token,对其施加惩罚(减少其
scores
)。
18. ForcedBOSTokenLogitsProcessor
功能:强制生成的序列以指定的BOS(开始)token开头。
输入参数
bos_token_id
(int
): 强制生成的开始token的ID。
处理逻辑
- 在生成的第一个时间步(
cur_len == 1
)时,将所有token的scores
设置为-inf
,仅保留bos_token_id
对应的scores
。
- 在生成的第一个时间步(
19. ForcedEOSTokenLogitsProcessor
功能:在达到最大长度
max_length
的倒数第二个时间步时,强制生成EOS(结束)token。输入参数
max_length
(int
): 生成序列的最大长度。eos_token_id
(Union[int, List[int], torch.Tensor]
): EOS token的ID。
处理逻辑
- 如果当前生成的序列长度为
max_length - 1
,则将所有token的scores
设置为-inf
,仅保留eos_token_id
对应的scores
。
- 如果当前生成的序列长度为
20. InfNanRemoveLogitsProcessor
功能:移除
scores
中可能存在的inf
和nan
值,防止计算过程中出现数值错误。处理逻辑
- 将
scores
中值为nan
的位置替换为0.0
。 - 将
scores
中值为正无穷的替换为数据类型的最大值,将负无穷替换为最小值。
- 将
21. ExponentialDecayLengthPenalty
功能:对EOS token的得分进行指数衰减,鼓励模型在适当的长度范围内结束生成。
输入参数
exponential_decay_length_penalty
(Tuple[int, float]
): 包含起始索引start_index
和衰减因子decay_factor
的元组。eos_token_id
(Union[int, List[int], torch.Tensor]
): EOS token的ID。input_ids_seq_length
(int
): 输入序列的长度。
处理逻辑
- 当生成序列长度超过
regulation_start = start_index + input_ids_seq_length
后,逐步增加eos_token_id
的scores
,促进生成结束。 - 增加量为
penalty = abs(scores[:, eos_token_id]) * (decay_factor^(cur_len - regulation_start) - 1)
。
- 当生成序列长度超过
22. LogitNormalization
功能:对
scores
进行log-softmax归一化,确保概率分布的正确性。特别是在应用了一系列LogitsProcessor
后,可能需要重新归一化。处理逻辑
- 对
scores
应用log_softmax
操作:scores = log_softmax(scores, dim=-1)
。
- 对
23. SuppressTokensAtBeginLogitsProcessor
功能:在生成开始时抑制指定的tokens,防止它们在序列的开头被生成。
输入参数
begin_suppress_tokens
(List[int]
): 要在序列开头抑制的token ID列表。begin_index
(int
): 指定从哪个位置开始抑制。
处理逻辑
- 在生成序列长度为
begin_index
时,将scores
中对应于begin_suppress_tokens
的值设置为-inf
。
- 在生成序列长度为
24. SuppressTokensLogitsProcessor
功能:全局抑制指定的tokens,防止它们在任何位置被生成。
输入参数
suppress_tokens
(List[int]
): 要抑制的token ID列表。
处理逻辑
- 将
scores
中对应于suppress_tokens
的值设置为-inf
。
- 将
25. WhisperTimeStampLogitsProcessor
功能:用于Whisper模型,调整时间戳token的
scores
,控制时间戳的生成逻辑,确保时间戳的对称性和有效性。输入参数
generate_config
: 包含Whisper模型特定配置的对象,包括no_timestamps_token_id
、eos_token_id
等。begin_index
(int
): 生成开始的索引位置。
处理逻辑
- 调整时间戳token的
scores
,确保时间戳成对出现,防止非法的时间戳序列。 - 在特定条件下,抑制非时间戳token或强制生成时间戳。
- 调整时间戳token的
26. WhisperNoSpeechDetection
功能:用于Whisper模型,检测无语音(静音)段落,并调整
scores
以反映检测结果。输入参数
no_speech_token
(int
): 表示无语音的token ID。begin_index
(int
): 生成开始的索引位置。
处理逻辑
- 在生成开始时,计算无语音的概率
no_speech_prob
。 - 该概率可用于控制后续的生成过程,例如在检测到静音时,提前终止生成。
- 在生成开始时,计算无语音的概率
27. ClassifierFreeGuidanceLogitsProcessor
功能:实现分类器自由引导(CFG),结合条件和无条件的
scores
,以控制生成文本的倾向性。输入参数
guidance_scale
(float
): 引导系数,大于1时加强对条件的依赖。
处理逻辑
- 将
input_ids
和scores
拆分为条件和无条件两部分。 - 计算处理后的
scores
:scores = uncond_scores + guidance_scale * (cond_scores - uncond_scores)
。
- 将
28. AlternatingCodebooksLogitsProcessor
功能:用于Bark模型的fine子模型,强制生成过程在两个codebook之间交替进行。
输入参数
input_start_len
(int
): 初始输入序列的长度。semantic_vocab_size
(int
): 语义词汇表的大小。codebook_size
(int
): codebook的大小。
处理逻辑
- 根据当前生成的序列长度,确定当前应该使用哪个codebook。
- 在
scores
中抑制不属于当前codebook的tokens。
29. UnbatchedClassifierFreeGuidanceLogitsProcessor
功能:未批量化的CFG处理器,适用于一次只处理一个样本的场景。
输入参数
guidance_scale
(float
): 引导系数。model
: 用于计算无条件logits
的模型实例。unconditional_ids
(torch.LongTensor
): 无条件的输入IDs。
处理逻辑
- 使用
model
计算无条件的logits
。 - 结合条件和无条件的
logits
,根据guidance_scale
调整scores
。
- 使用
30. BarkEosPrioritizerLogitsProcessor
功能:用于Bark模型,在满足特定条件时优先选择EOS token,确保生成序列适时结束。
输入参数
eos_token_id
(Union[int, List[int], torch.Tensor]
): EOS token的ID。min_eos_p
(float
): 最小的EOS token概率阈值。
处理逻辑
- 计算
scores
对应的softmax概率probs
。 - 如果
probs
中eos_token_id
的概率超过min_eos_p
,则强制选择EOS token。
- 计算
31. WatermarkLogitsProcessor
功能:在生成的文本中嵌入水印,通过对特定的“绿色”token添加偏置,实现不可察觉的标记,用于检测生成内容的来源。
输入参数
vocab_size
(int
): 词汇表的大小。device
(str
): 计算设备。greenlist_ratio
(float
): “绿色”token占词汇表的比例。bias
(float
): 添加到“绿色”token的偏置值。
处理逻辑
- 根据当前生成的序列,使用随机生成器确定“绿色”token列表。
- 在
scores
中对应于“绿色”token的位置添加bias
,增加它们被选中的概率。
32. SynthIDTextWatermarkLogitsProcessor
功能:实现SynthID文本水印的logits处理器,在生成过程中微调token的概率以嵌入水印,辅助检测生成内容的真实性。
输入参数
ngram_len
(int
): n-gram的长度,用于生成水印密钥。keys
(List[int]
): 用于水印的密钥序列。sampling_table_size
(int
): 采样表的大小。sampling_table_seed
(int
): 生成采样表的随机种子。
处理逻辑
- 使用哈希函数根据当前生成的n-gram序列计算密钥。
- 根据密钥从采样表中获取g值,调整对应的
scores
,从而嵌入水印。 - 该过程确保嵌入的水印在统计上不可察觉,但可以通过对应的检测算法进行识别。
全局方法详细说明
1. _get_ngrams
功能:获取已生成的指定大小的n-gram序列。
参数
ngram_size
(int
): n-gram的大小。prev_input_ids
(torch.Tensor
): 之前生成的输入IDs。num_hypos
(int
): 假设的数量(batch size)。
返回
generated_ngrams
(dict
): 每个假设对应的已生成n-gram字典。
2. _get_generated_ngrams
功能:确定在当前生成位置应该被禁止的token列表,防止重复生成相同的n-gram。
参数
banned_ngrams
(dict
): 已生成的n-gram字典。prev_input_ids
(torch.Tensor
): 之前生成的输入IDs。ngram_size
(int
): n-gram的大小。cur_len
(int
): 当前生成的序列长度。
返回
banned_tokens
(List[int]
): 应该被禁止的token列表。
3. _calc_banned_ngram_tokens
功能:计算每个假设中应该被禁止的n-gram token列表,防止重复生成。
参数
ngram_size
(int
): n-gram的大小。prev_input_ids
(torch.Tensor
): 之前生成的输入IDs。num_hypos
(int
): 假设的数量。cur_len
(int
): 当前生成的序列长度。
返回
banned_tokens
(List[Iterable[int]]
): 每个假设对应的被禁止token列表。
通过上述处理,各个LogitsProcessor
派生类在生成过程中对scores
进行多样化的调整,以实现特定的生成行为,如避免重复、控制长度、引导话题、嵌入水印等。这些调整确保了生成的文本既符合预期,又具有自然流畅的特点。