LogitsProcessor代码分析

发布于:2025-03-31 ⋅ 阅读:(14) ⋅ 点赞:(0)

LogitsProcessor是一个抽象基类,用于在生成序列的过程中对模型输出的logits进行处理。它的派生类实现了各种策略,以控制生成过程。


公共输入和输出

所有的LogitsProcessor派生类都遵循相同的调用约定,即实现了__call__方法,接受以下输入并返回处理后的logits:

  • 输入

    • input_ids (torch.LongTensor): 形状为(batch_size, sequence_length)的张量,表示当前已生成的序列的token索引。
    • scores (torch.FloatTensor): 形状为(batch_size, vocab_size)的张量,表示模型在当前时间步对所有词汇表中每个token的预测得分(通常是logits)。
  • 输出

    • scores (torch.FloatTensor): 形状为(batch_size, vocab_size)的张量,表示经过处理后的预测得分。这些得分随后将用于采样下一个token或选择概率最高的token。

类名/方法名 简要说明
说明
LogitsProcessor 所有在生成过程中可以应用的logits处理器的抽象基类。
LogitsProcessorList 用于创建一个包含多个LogitsProcessor的列表,以便对scores张量进行依次处理。
MinLengthLogitsProcessor 强制生成的序列至少达到最小长度,方法是在达到最小长度之前将EOS(结束)token的概率设置为负无穷大。
MinNewTokensLengthLogitsProcessor 强制新生成的token序列至少达到最小长度,不包括提示(prompt)部分。
TemperatureLogitsWarper 通过应用温度缩放来调节预测中token的概率分布,从而控制生成文本的随机性。
RepetitionPenaltyLogitsProcessor 对已生成的token施加惩罚,防止模型重复生成相同的token序列。
EncoderRepetitionPenaltyLogitsProcessor 对输入(encoder)的token施加惩罚或奖励,鼓励模型重复或避免重复输入中的内容。
TopPLogitsWarper 实现核采样(top-p sampling),仅保留累积概率达到top_p阈值的最可能的token。
TopKLogitsWarper 实现top-k采样,仅保留概率最高的k个token。
MinPLogitsWarper 保留概率高于给定最小值min_p的token,滤除低概率的token。
TypicalLogitsWarper 实现典型采样,优先选择熵较低的token,避免选择过于常见或罕见的词汇。
EpsilonLogitsWarper 实现epsilon采样,保留概率大于epsilon的token,确保选出的token概率超过一定阈值。
EtaLogitsWarper 基于熵动态调整eta阈值,对低概率的token进行过滤,适应性地选择token。
NoRepeatNGramLogitsProcessor 防止生成重复的n-gram序列,通过禁止已生成的n-gram再次出现。
EncoderNoRepeatNGramLogitsProcessor 类似于NoRepeatNGramLogitsProcessor,但针对的是输入(encoder)中已存在的n-gram,防止在生成中重复输入的n-gram。
SequenceBiasLogitsProcessor 对特定的token序列施加加性偏置,可以增加或减少这些序列在生成中的概率。
NoBadWordsLogitsProcessor 禁止生成包含特定“不良”词汇的序列,设置这些词汇的概率为负无穷大。
PrefixConstrainedLogitsProcessor 在生成过程中,强制遵循指定的前缀token列表,限制生成的可能性,以满足特定的前缀约束。
HammingDiversityLogitsProcessor 在beam search中引入哈夫曼多样性惩罚,鼓励生成的序列之间的多样性,避免不同beam生成相似的序列。
ForcedBOSTokenLogitsProcessor 强制生成的序列以指定的开始(BOS)token开头,通常用于encoder-decoder模型。
ForcedEOSTokenLogitsProcessor 在达到最大长度时,强制生成指定的结束(EOS)token,确保序列的终止。
InfNanRemoveLogitsProcessor 移除logits中的无穷大(inf)和非数字(nan)值,避免生成过程中的数值错误。
ExponentialDecayLengthPenalty 对EOS token的得分进行指数衰减,鼓励模型在适当的长度范围内结束生成,避免过长的输出。
LogitNormalization 对logits进行log-softmax归一化,确保概率分布的正确性,常用于beam search后。
SuppressTokensAtBeginLogitsProcessor 在生成开始时抑制特定的token,防止它们在序列的开头被生成。
SuppressTokensLogitsProcessor 全局抑制指定的token,设置它们的logits为负无穷大,防止在任何位置生成。
WhisperTimeStampLogitsProcessor 修改时间戳token的logits,用于控制生成时间戳的行为,常用于Whisper模型的转录任务。
WhisperNoSpeechDetection 检测无语音段落,在检测到静音时,调整logits以反映无语音的状态。
ClassifierFreeGuidanceLogitsProcessor 实现分类器自由引导(CFG),结合有条件和无条件的logits,以控制生成文本的倾向性。
AlternatingCodebooksLogitsProcessor 在Bark模型的fine子模型中交替生成两个codebook,控制生成过程中的模式。
UnbatchedClassifierFreeGuidanceLogitsProcessor 未批量化的CFG logits处理器,适用于非批量处理的模型。
BarkEosPrioritizerLogitsProcessor 对于Bark模型,在满足特定条件时优先选择EOS token,确保序列适时结束。
WatermarkLogitsProcessor 在生成的文本中嵌入水印,通过对特定的“绿色”token添加偏置,实现不可察觉的标记。
SynthIDTextWatermarkState SynthID文本水印的状态类,用于跟踪已生成的token序列和水印相关的信息。
SynthIDTextWatermarkLogitsProcessor 实现SynthID文本水印的logits处理器,在生成过程中微调token的概率以嵌入水印,辅助检测生成内容的真实性。
全局方法 说明
_get_ngrams(ngram_size, prev_input_ids, num_hypos) 获取已生成的n-gram序列,用于防止重复生成。
_get_generated_ngrams(banned_ngrams, prev_input_ids, ngram_size, cur_len) 确定当前生成位置应被禁止的token列表,基于之前生成的n-gram。
_calc_banned_ngram_tokens(ngram_size, prev_input_ids, num_hypos, cur_len) 计算每个假设(hypothesis)中应被禁止的n-gram token,防止重复。

上述类和方法共同构成了生成文本时调整和控制模型输出的机制,帮助生成更符合预期的、流畅的文本。

派生类详细说明

1. MinLengthLogitsProcessor
  • 功能:在生成序列未达到最小长度时,禁止生成结束符(EOS token),以确保生成的序列至少达到指定的最小长度。

  • 输入参数

    • min_length (int): 最小生成长度。如果当前生成的序列长度小于min_length,则禁止生成EOS token。
    • eos_token_id (Union[int, List[int], torch.Tensor]): EOS token的ID或ID列表。
  • 处理逻辑

    • 检查input_ids的长度cur_len是否小于min_length
    • 如果是,则将scores中对应于eos_token_id的位置设置为负无穷大-inf,以禁止模型在此位置生成EOS token。

2. MinNewTokensLengthLogitsProcessor
  • 功能:与MinLengthLogitsProcessor类似,但只考虑新生成的tokens的长度,不包括初始的提示(prompt)部分。

  • 输入参数

    • prompt_length_to_skip (int): 提示部分的长度,在计算新生成的token长度时跳过此部分。
    • min_new_tokens (int): 新生成的tokens的最小长度。
    • eos_token_id (Union[int, List[int], torch.Tensor]): EOS token的ID或ID列表。
  • 处理逻辑

    • 计算新生成的tokens的长度:new_tokens_length = len(input_ids) - prompt_length_to_skip
    • 如果new_tokens_length小于min_new_tokens,则将scores中对应于eos_token_id的位置设置为-inf,以禁止生成EOS token。

3. TemperatureLogitsWarper
  • 功能:通过温度缩放调节预测分布的平坦程度,影响生成过程的随机性。温度较高时(>1),分布更平坦,随机性更大;温度较低时(<1),分布更尖锐,更倾向于高概率的tokens。

  • 输入参数

    • temperature (float): 温度值,需为正数。调整logits的平坦程度。
  • 处理逻辑

    • scores除以temperaturescores = scores / temperature
    • 温度高于1时,scores较小,softmax后分布更平坦;温度低于1时,scores较大,softmax后分布更尖锐。

4. RepetitionPenaltyLogitsProcessor
  • 功能:对已生成的tokens施加惩罚,防止重复生成相同的tokens。对于已经生成的tokens,如果它们再次出现,其对应的scores会被调整。

  • 输入参数

    • penalty (float): 惩罚系数。大于1表示惩罚重复的tokens,介于0到1之间表示鼓励重复。
  • 处理逻辑

    • 对于每个batch,获取已生成的tokens input_ids在当前scores中的得分。
    • 根据得分的正负,对应地除以或乘以penalty
      • 如果得分<0,则乘以penalty(降低负得分的绝对值)。
      • 如果得分>0,则除以penalty(降低正得分)。
    • 更新scores中对应于已生成tokens的位置。

5. EncoderRepetitionPenaltyLogitsProcessor
  • 功能:与RepetitionPenaltyLogitsProcessor类似,但针对的是输入序列(encoder input ids),用于在生成序列时避免重复输入内容。

  • 输入参数

    • penalty (float): 惩罚系数。
    • encoder_input_ids (torch.LongTensor): 输入序列的token IDs。
  • 处理逻辑

    • 对于encoder_input_ids中的每个token,获取其在scores中的得分。
    • 根据得分的正负,对应地除以或乘以penalty
    • 更新scores中对应于encoder_input_ids的位置。

6. TopPLogitsWarper
  • 功能:实现核采样(top-p sampling),仅保留累积概率达到top_p阈值的tokens,滤除较低概率的tokens。

  • 输入参数

    • top_p (float): 累积概率阈值,0 < top_p <= 1。保留概率最高的tokens,使其累积概率达到top_p
    • filter_value (float): 被滤除的token的得分设置为该值,通常为-inf
    • min_tokens_to_keep (int): 至少保留的token数量,不论top_p的值。
  • 处理逻辑

    • scores进行排序,计算对应的softmax概率并累积(cumulative sum)。
    • 找到使累积概率大于(1 - top_p)的位置,将这些位置对应的tokens标记为需要滤除。
    • 确保至少保留min_tokens_to_keep个tokens。
    • 将被滤除的tokens的scores设置为filter_value

7. TopKLogitsWarper
  • 功能:仅保留概率最高的top_k个tokens,其余的tokens被滤除。

  • 输入参数

    • top_k (int): 要保留的tokens数量。
    • filter_value (float): 被滤除的token的得分设置为该值。
    • min_tokens_to_keep (int): 至少保留的token数量。
  • 处理逻辑

    • 找到scores中得分最高的top_k个tokens。
    • 将其余tokens的scores设置为filter_value
    • 确保至少保留min_tokens_to_keep个tokens。

8. MinPLogitsWarper
  • 功能:保留概率高于min_p的tokens。min_p根据最高概率的token进行缩放,动态调整阈值。

  • 输入参数

    • min_p (float): 最小概率阈值。
    • filter_value (float): 被滤除的token的得分设置为该值。
    • min_tokens_to_keep (int): 至少保留的token数量。
  • 处理逻辑

    • 计算scores对应的softmax概率probs
    • 获取每个batch中最高概率top_probs,计算scaled_min_p = min_p * top_probs
    • 滤除probs小于scaled_min_p的tokens。
    • 确保至少保留min_tokens_to_keep个tokens。

9. TypicalLogitsWarper
  • 功能:实现典型采样(typical sampling),优先选择与整个概率分布的熵接近的tokens,避免选择过于常见或罕见的tokens。

  • 输入参数

    • mass (float): 累积概率阈值,0 < mass < 1
    • filter_value (float): 被滤除的token的得分设置为该值。
    • min_tokens_to_keep (int): 至少保留的token数量。
  • 处理逻辑

    • 计算scores对应的对数概率normalized和概率p
    • 计算熵entropy = - (normalized * p).sum()
    • 计算偏移值shifted_scores = abs(-normalized - entropy)
    • shifted_scores进行排序,计算累积概率cumulative_probs
    • 滤除令累积概率超过mass的tokens。
    • 确保至少保留min_tokens_to_keep个tokens。

10. EpsilonLogitsWarper
  • 功能:实现epsilon采样,保留概率大于epsilon的tokens。如果没有token满足条件,则保留概率最高的min_tokens_to_keep个tokens。

  • 输入参数

    • epsilon (float): 最小概率阈值。
    • filter_value (float): 被滤除的token的得分设置为该值。
    • min_tokens_to_keep (int): 至少保留的token数量。
  • 处理逻辑

    • 计算scores对应的softmax概率probs
    • 滤除probs小于epsilon的tokens。
    • 确保至少保留min_tokens_to_keep个tokens。

11. EtaLogitsWarper
  • 功能:根据熵动态调整eta值,对低概率的tokens进行过滤。eta的计算与当前分布的熵相关。

  • 输入参数

    • epsilon (float): 用于计算动态阈值eta的参数。
    • filter_value (float): 被滤除的token的得分设置为该值。
    • min_tokens_to_keep (int): 至少保留的token数量。
    • device (str): 计算所使用的设备。
  • 处理逻辑

    • 计算scores对应的softmax概率probs
    • 计算熵entropy
    • 计算动态阈值eta = min(epsilon, sqrt(epsilon) * exp(-entropy))
    • 滤除probs小于eta的tokens。
    • 确保至少保留min_tokens_to_keep个tokens。

12. NoRepeatNGramLogitsProcessor
  • 功能:防止生成重复的n-gram序列。对于已经生成的n-gram,不允许再次生成。

  • 输入参数

    • ngram_size (int): n-gram的大小。例如,ngram_size=2表示防止重复的二元组。
  • 处理逻辑

    • 调用_calc_banned_ngram_tokens函数,获取当前时间步每个batch中应被禁止的token列表。
    • scores中对应于被禁止tokens的位置设置为-inf,以防止生成这些token。

13. EncoderNoRepeatNGramLogitsProcessor
  • 功能:防止生成与输入序列中存在的n-gram重复的序列。

  • 输入参数

    • encoder_ngram_size (int): n-gram的大小。
    • encoder_input_ids (torch.LongTensor): 输入序列的token IDs。
  • 处理逻辑

    • 基于encoder_input_ids预计算n-gram序列。
    • 对于每个生成的序列,检查是否与输入的n-gram重复。
    • 将重复的n-gram对应的scores位置设置为-inf

14. SequenceBiasLogitsProcessor
  • 功能:对指定的token或token序列施加偏置,可以增加或降低其被生成的概率。

  • 输入参数

    • sequence_bias (List[List[Union[List[int], float]]]): 包含序列和对应偏置值的列表。例如,[[[token1_id, token2_id], bias_value]]
  • 处理逻辑

    • 对于长度为1的序列,直接在scores中对应的位置加上偏置值。
    • 对于长度大于1的序列,检查前缀是否匹配,如果匹配则对下一个可能完成序列的token加上偏置值。

15. NoBadWordsLogitsProcessor
  • 功能:禁止生成包含特定“不良”词汇的序列,将这些词汇对应的scores设置为-inf

  • 输入参数

    • bad_words_ids (List[List[int]]): 要禁止的词汇列表,每个词汇由其token ID列表表示。
    • eos_token_id (Union[int, List[int], torch.Tensor]): EOS token的ID。
  • 处理逻辑

    • bad_words_ids中的序列转换为SequenceBiasLogitsProcessor的形式,偏置值设置为-inf
    • 在生成过程中,确保这些“不良”词汇不被生成。

16. PrefixConstrainedLogitsProcessor
  • 功能:根据指定的函数prefix_allowed_tokens_fn限制每个时间步可能生成的token集合,实现前缀约束。

  • 输入参数

    • prefix_allowed_tokens_fn (Callable[[int, torch.Tensor], List[int]]): 一个函数,接收batch_idinput_ids,返回允许的token ID列表。
    • num_beams (int): beam search的beam大小。
  • 处理逻辑

    • 对于每个batch和beam,调用prefix_allowed_tokens_fn获取允许的token列表。
    • 将不在允许列表中的token的scores设置为-inf

17. HammingDiversityLogitsProcessor
  • 功能:在group beam search中引入多样性惩罚,鼓励不同beam group生成多样化的序列。

  • 输入参数

    • diversity_penalty (float): 多样性惩罚系数。
    • num_beams (int): beam的总数量。
    • num_beam_groups (int): beam group的数量。
  • 处理逻辑

    • 在生成过程中,每个beam group之间共享信息。
    • 对于在当前时间步不同group生成的token,计算它们的频率token_frequency
    • 对于在其他group中已生成的token,对其施加惩罚(减少其scores)。

18. ForcedBOSTokenLogitsProcessor
  • 功能:强制生成的序列以指定的BOS(开始)token开头。

  • 输入参数

    • bos_token_id (int): 强制生成的开始token的ID。
  • 处理逻辑

    • 在生成的第一个时间步(cur_len == 1)时,将所有token的scores设置为-inf,仅保留bos_token_id对应的scores

19. ForcedEOSTokenLogitsProcessor
  • 功能:在达到最大长度max_length的倒数第二个时间步时,强制生成EOS(结束)token。

  • 输入参数

    • max_length (int): 生成序列的最大长度。
    • eos_token_id (Union[int, List[int], torch.Tensor]): EOS token的ID。
  • 处理逻辑

    • 如果当前生成的序列长度为max_length - 1,则将所有token的scores设置为-inf,仅保留eos_token_id对应的scores

20. InfNanRemoveLogitsProcessor
  • 功能:移除scores中可能存在的infnan值,防止计算过程中出现数值错误。

  • 处理逻辑

    • scores中值为nan的位置替换为0.0
    • scores中值为正无穷的替换为数据类型的最大值,将负无穷替换为最小值。

21. ExponentialDecayLengthPenalty
  • 功能:对EOS token的得分进行指数衰减,鼓励模型在适当的长度范围内结束生成。

  • 输入参数

    • exponential_decay_length_penalty (Tuple[int, float]): 包含起始索引start_index和衰减因子decay_factor的元组。
    • eos_token_id (Union[int, List[int], torch.Tensor]): EOS token的ID。
    • input_ids_seq_length (int): 输入序列的长度。
  • 处理逻辑

    • 当生成序列长度超过regulation_start = start_index + input_ids_seq_length后,逐步增加eos_token_idscores,促进生成结束。
    • 增加量为penalty = abs(scores[:, eos_token_id]) * (decay_factor^(cur_len - regulation_start) - 1)

22. LogitNormalization
  • 功能:对scores进行log-softmax归一化,确保概率分布的正确性。特别是在应用了一系列LogitsProcessor后,可能需要重新归一化。

  • 处理逻辑

    • scores应用log_softmax操作:scores = log_softmax(scores, dim=-1)

23. SuppressTokensAtBeginLogitsProcessor
  • 功能:在生成开始时抑制指定的tokens,防止它们在序列的开头被生成。

  • 输入参数

    • begin_suppress_tokens (List[int]): 要在序列开头抑制的token ID列表。
    • begin_index (int): 指定从哪个位置开始抑制。
  • 处理逻辑

    • 在生成序列长度为begin_index时,将scores中对应于begin_suppress_tokens的值设置为-inf

24. SuppressTokensLogitsProcessor
  • 功能:全局抑制指定的tokens,防止它们在任何位置被生成。

  • 输入参数

    • suppress_tokens (List[int]): 要抑制的token ID列表。
  • 处理逻辑

    • scores中对应于suppress_tokens的值设置为-inf

25. WhisperTimeStampLogitsProcessor
  • 功能:用于Whisper模型,调整时间戳token的scores,控制时间戳的生成逻辑,确保时间戳的对称性和有效性。

  • 输入参数

    • generate_config: 包含Whisper模型特定配置的对象,包括no_timestamps_token_ideos_token_id等。
    • begin_index (int): 生成开始的索引位置。
  • 处理逻辑

    • 调整时间戳token的scores,确保时间戳成对出现,防止非法的时间戳序列。
    • 在特定条件下,抑制非时间戳token或强制生成时间戳。

26. WhisperNoSpeechDetection
  • 功能:用于Whisper模型,检测无语音(静音)段落,并调整scores以反映检测结果。

  • 输入参数

    • no_speech_token (int): 表示无语音的token ID。
    • begin_index (int): 生成开始的索引位置。
  • 处理逻辑

    • 在生成开始时,计算无语音的概率no_speech_prob
    • 该概率可用于控制后续的生成过程,例如在检测到静音时,提前终止生成。

27. ClassifierFreeGuidanceLogitsProcessor
  • 功能:实现分类器自由引导(CFG),结合条件和无条件的scores,以控制生成文本的倾向性。

  • 输入参数

    • guidance_scale (float): 引导系数,大于1时加强对条件的依赖。
  • 处理逻辑

    • input_idsscores拆分为条件和无条件两部分。
    • 计算处理后的scoresscores = uncond_scores + guidance_scale * (cond_scores - uncond_scores)

28. AlternatingCodebooksLogitsProcessor
  • 功能:用于Bark模型的fine子模型,强制生成过程在两个codebook之间交替进行。

  • 输入参数

    • input_start_len (int): 初始输入序列的长度。
    • semantic_vocab_size (int): 语义词汇表的大小。
    • codebook_size (int): codebook的大小。
  • 处理逻辑

    • 根据当前生成的序列长度,确定当前应该使用哪个codebook。
    • scores中抑制不属于当前codebook的tokens。

29. UnbatchedClassifierFreeGuidanceLogitsProcessor
  • 功能:未批量化的CFG处理器,适用于一次只处理一个样本的场景。

  • 输入参数

    • guidance_scale (float): 引导系数。
    • model: 用于计算无条件logits的模型实例。
    • unconditional_ids (torch.LongTensor): 无条件的输入IDs。
  • 处理逻辑

    • 使用model计算无条件的logits
    • 结合条件和无条件的logits,根据guidance_scale调整scores

30. BarkEosPrioritizerLogitsProcessor
  • 功能:用于Bark模型,在满足特定条件时优先选择EOS token,确保生成序列适时结束。

  • 输入参数

    • eos_token_id (Union[int, List[int], torch.Tensor]): EOS token的ID。
    • min_eos_p (float): 最小的EOS token概率阈值。
  • 处理逻辑

    • 计算scores对应的softmax概率probs
    • 如果probseos_token_id的概率超过min_eos_p,则强制选择EOS token。

31. WatermarkLogitsProcessor
  • 功能:在生成的文本中嵌入水印,通过对特定的“绿色”token添加偏置,实现不可察觉的标记,用于检测生成内容的来源。

  • 输入参数

    • vocab_size (int): 词汇表的大小。
    • device (str): 计算设备。
    • greenlist_ratio (float): “绿色”token占词汇表的比例。
    • bias (float): 添加到“绿色”token的偏置值。
  • 处理逻辑

    • 根据当前生成的序列,使用随机生成器确定“绿色”token列表。
    • scores中对应于“绿色”token的位置添加bias,增加它们被选中的概率。

32. SynthIDTextWatermarkLogitsProcessor
  • 功能:实现SynthID文本水印的logits处理器,在生成过程中微调token的概率以嵌入水印,辅助检测生成内容的真实性。

  • 输入参数

    • ngram_len (int): n-gram的长度,用于生成水印密钥。
    • keys (List[int]): 用于水印的密钥序列。
    • sampling_table_size (int): 采样表的大小。
    • sampling_table_seed (int): 生成采样表的随机种子。
  • 处理逻辑

    • 使用哈希函数根据当前生成的n-gram序列计算密钥。
    • 根据密钥从采样表中获取g值,调整对应的scores,从而嵌入水印。
    • 该过程确保嵌入的水印在统计上不可察觉,但可以通过对应的检测算法进行识别。

全局方法详细说明

1. _get_ngrams
  • 功能:获取已生成的指定大小的n-gram序列。

  • 参数

    • ngram_size (int): n-gram的大小。
    • prev_input_ids (torch.Tensor): 之前生成的输入IDs。
    • num_hypos (int): 假设的数量(batch size)。
  • 返回

    • generated_ngrams (dict): 每个假设对应的已生成n-gram字典。

2. _get_generated_ngrams
  • 功能:确定在当前生成位置应该被禁止的token列表,防止重复生成相同的n-gram。

  • 参数

    • banned_ngrams (dict): 已生成的n-gram字典。
    • prev_input_ids (torch.Tensor): 之前生成的输入IDs。
    • ngram_size (int): n-gram的大小。
    • cur_len (int): 当前生成的序列长度。
  • 返回

    • banned_tokens (List[int]): 应该被禁止的token列表。

3. _calc_banned_ngram_tokens
  • 功能:计算每个假设中应该被禁止的n-gram token列表,防止重复生成。

  • 参数

    • ngram_size (int): n-gram的大小。
    • prev_input_ids (torch.Tensor): 之前生成的输入IDs。
    • num_hypos (int): 假设的数量。
    • cur_len (int): 当前生成的序列长度。
  • 返回

    • banned_tokens (List[Iterable[int]]): 每个假设对应的被禁止token列表。

通过上述处理,各个LogitsProcessor派生类在生成过程中对scores进行多样化的调整,以实现特定的生成行为,如避免重复、控制长度、引导话题、嵌入水印等。这些调整确保了生成的文本既符合预期,又具有自然流畅的特点。


网站公告

今日签到

点亮在社区的每一天
去签到