GPT生成文本

发布于:2023-05-01 ⋅ 阅读:(269) ⋅ 点赞:(0)
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel

# 可选:如果您想了解发生的信息,请按以下步骤logger
import logging
logging.basicConfig(level=logging.INFO)

# 加载预训练模型(权重)
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

# 编码输入
text = "Who was Jim Henson ? Jim Henson was a"
indexed_tokens = tokenizer.encode(text)

# 转换为PyTorch tensor
tokens_tensor = torch.tensor([indexed_tokens])
# 让我们看看如何使用GPT2LMHeadModel生成下一个跟在我们的文本后面的token:

# 加载预训练模型(权重)
model = GPT2LMHeadModel.from_pretrained('gpt2')

# 将模型设置为评估模式
# 在评估期间有可再现的结果这是很重要的!
model.eval()

# 如果你有GPU,把所有东西都放在cuda上
tokens_tensor = tokens_tensor.to('cuda')
model.to('cuda')

# 预测所有标记
with torch.no_grad():
    outputs = model(tokens_tensor)#[1, 11, 50257],[2, 1, 12, 11, 64]
    predictions = outputs[0]
# 得到预测的下一个子词(在我们的例子中,是“man”这个词)
predicted_index = torch.argmax(predictions[0, -1, :]).item()
predicted_text = tokenizer.decode(indexed_tokens + [predicted_index])
# assert predicted_text == 'Who was Jim Henson? Jim Henson was a man'
# 每个模型架构(Bert、GPT、GPT-2、Transformer XL、XLNet和XLM)的每个模型类的示例,可以在文档中找到。

# 使用过去的GPT-2
# 以及其他一些模型(GPT、XLNet、Transfo XL、CTRL),使用past或mems属性,这些属性可用于防止在使用顺序解码时重新计算键/值对。它在生成序列时很有用,因为注意力机制的很大一部分得益于以前的计算。
#
# 下面是一个使用带past的GPT2LMHeadModel和argmax解码的完整工作示例(只能作为示例,因为argmax decoding引入了大量重复):

from transformers import GPT2LMHeadModel, GPT2Tokenizer
import torch

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained('gpt2')

generated = tokenizer.encode("The Manhattan bridge")
context = torch.tensor([generated])
past = None

for i in range(100):
    output, past = model(context, past=past)
    token = torch.argmax(output[..., -1, :])

    generated += [token.tolist()]
    context = token.unsqueeze(0)

sequence = tokenizer.decode(generated)

print(sequence)

 


网站公告

今日签到

点亮在社区的每一天
去签到