transformers - Predicting the Middle Word

Published: 2024-04-25

Code


from transformers import AutoTokenizer

# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained('distilroberta-base', use_fast=True)

print(tokenizer)

# Trial encoding of a couple of sentences
tokenizer.batch_encode_plus([
    'hide new secretions from the parental units',
    'contains no wit , only labored gags'
])

PreTrainedTokenizerFast(name_or_path='distilroberta-base', vocab_size=50265, model_max_len=512, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<s>', 'eos_token': '</s>', 'unk_token': '<unk>', 'sep_token': '</s>', 'pad_token': '<pad>', 'cls_token': '<s>', 'mask_token': AddedToken("<mask>", rstrip=False, lstrip=True, single_word=False, normalized=False)})
{'input_ids': [[0, 37265, 92, 3556, 2485, 31, 5, 20536, 2833, 2], [0, 10800, 5069, 117, 22094, 2156, 129, 6348, 3995, 821, 8299, 2]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]}
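
As a quick sanity check, decoding the first id sequence should recover the sentence, wrapped in the <s>/</s> special tokens:

# ids copied from the batch_encode_plus output above
print(tokenizer.decode([0, 37265, 92, 3556, 2485, 31, 5, 20536, 2833, 2]))
# -> '<s>hide new secretions from the parental units</s>'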

Loading the Data

from datasets import load_dataset, load_from_disk

# Load the dataset
dataset = load_dataset(path='glue', name='sst2')
# dataset = load_from_disk('datas/glue/sst2')


# Tokenize, and remove the columns that are no longer needed
def f(data):
    return tokenizer.batch_encode_plus(data['sentence'])


dataset = dataset.map(f,
                      batched=True,
                      batch_size=1000,
                      num_proc=4,
                      remove_columns=['sentence', 'idx', 'label'])


# Filter out sentences that are too short (fewer than 9 tokens)
def f(data):
    return [len(i) >= 9 for i in data['input_ids']]


dataset = dataset.filter(f, batched=True, batch_size=1000, num_proc=4)


# Truncate the sentences and arrange them into the format the model needs
def f(data):
    b = len(data['input_ids'])
    data['labels'] = data['attention_mask'].copy()
    for i in range(b):
        # Truncate to length 9
        data['input_ids'][i] = data['input_ids'][i][:9]
        data['attention_mask'][i] = [1] * 9
        data['labels'][i] = [-100] * 9

        # The last position of input_ids is 2, i.e. the </s> token
        data['input_ids'][i][-1] = 2

        # Replace the token at index 4 of every sentence with <mask>
        # tokenizer.get_vocab()['<mask>'] -> 50264
        data['labels'][i][4] = data['input_ids'][i][4]
        data['input_ids'][i][4] = 50264

    return data


dataset = dataset.map(f, batched=True, batch_size=1000, num_proc=4)

dataset, dataset['train'][0]
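
Each processed example should now be exactly 9 tokens long, end with </s> (id 2), and carry <mask> (id 50264) at index 4, with labels set to -100 everywhere except index 4; a minimal check:

sample = dataset['train'][0]
assert len(sample['input_ids']) == 9
assert sample['input_ids'][-1] == 2
assert sample['input_ids'][4] == 50264
assert all(l == -100 for i, l in enumerate(sample['labels']) if i != 4)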

import torch
from transformers.data.data_collator import default_data_collator

# A collate_fn that applies random masking
# If you use this utility class, there is no need to insert the mask during preprocessing; just set labels = input_ids.copy() (see the sketch after the batch printout below)
#from transformers import DataCollatorForLanguageModeling
#data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer,mlm_probability=0.1)

# Data loader
loader = torch.utils.data.DataLoader(
    dataset=dataset['train'],
    batch_size=8,
    collate_fn=default_data_collator,
    shuffle=True,
    drop_last=True,
)

# Grab one batch to inspect
for i, data in enumerate(loader):
    break

len(loader), data

(5534,
 {'input_ids': tensor([[    0, 12196,   128,    29, 50264, 10132,    59,  9326,     2],
          [    0,  1250,     5,  3768, 50264, 34948, 16658,     8,     2],
          [    0,   627,   936,    16, 50264,   240, 12445,  2129,     2],
          [    0,  3654,   350, 13185, 50264,    45,   350,  8794,     2],
          [    0,   560,    28,    56, 50264,  3541, 34261,    19,     2],
          [    0,   560,   224,    14, 50264,   473,   295,    75,     2],
          [    0,     6, 14784,  1054, 50264,    10,   686,   865,     2],
          [    0,  9006,  1495,  2156, 50264, 23317,  4780,     8,     2]]),
  'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1],
          [1, 1, 1, 1, 1, 1, 1, 1, 1],
          [1, 1, 1, 1, 1, 1, 1, 1, 1],
          [1, 1, 1, 1, 1, 1, 1, 1, 1],
          [1, 1, 1, 1, 1, 1, 1, 1, 1],
          [1, 1, 1, 1, 1, 1, 1, 1, 1],
          [1, 1, 1, 1, 1, 1, 1, 1, 1],
          [1, 1, 1, 1, 1, 1, 1, 1, 1]]),
  'labels': tensor([[-100, -100, -100, -100,  144, -100, -100, -100, -100],
          [-100, -100, -100, -100,   32, -100, -100, -100, -100],
          [-100, -100, -100, -100,    5, -100, -100, -100, -100],
          [-100, -100, -100, -100, 2156, -100, -100, -100, -100],
          [-100, -100, -100, -100,   31, -100, -100, -100, -100],
          [-100, -100, -100, -100,   24, -100, -100, -100, -100],
          [-100, -100, -100, -100,   34, -100, -100, -100, -100],
          [-100, -100, -100, -100,   10, -100, -100, -100, -100]])})
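
For reference, a minimal sketch of the random-masking alternative mentioned above. It assumes a dataset that was only tokenized (without the fixed position-4 mask): the collator clones input_ids into labels, masks a random ~10% of tokens, and sets every unmasked label position to -100.

from transformers import DataCollatorForLanguageModeling

data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer,
                                                mlm=True,
                                                mlm_probability=0.1)

loader_random = torch.utils.data.DataLoader(
    dataset=dataset['train'],
    batch_size=8,
    collate_fn=data_collator,
    shuffle=True,
    drop_last=True,
)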

from transformers import AutoModelForCausalLM, RobertaModel

# Load the model
#model = AutoModelForCausalLM.from_pretrained('distilroberta-base')


# Define the downstream-task model
class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.pretrained = RobertaModel.from_pretrained('distilroberta-base')

        decoder = torch.nn.Linear(768, tokenizer.vocab_size)
        decoder.bias = torch.nn.Parameter(torch.zeros(tokenizer.vocab_size))

        self.fc = torch.nn.Sequential(
            torch.nn.Linear(768, 768),
            torch.nn.GELU(),
            torch.nn.LayerNorm(768, eps=1e-5),
            decoder,
        )

        # Load the pretrained model's lm_head parameters into the new head
        parameters = AutoModelForCausalLM.from_pretrained('distilroberta-base')
        self.fc[0].load_state_dict(parameters.lm_head.dense.state_dict())
        self.fc[2].load_state_dict(parameters.lm_head.layer_norm.state_dict())
        self.fc[3].load_state_dict(parameters.lm_head.decoder.state_dict())

        self.criterion = torch.nn.CrossEntropyLoss()

    def forward(self, input_ids, attention_mask, labels=None):
        logits = self.pretrained(input_ids=input_ids,
                                 attention_mask=attention_mask)
        logits = logits.last_hidden_state

        logits = self.fc(logits)

        loss = None
        if labels is not None:
            # Shift by one position (causal-LM convention): the logits at
            # position i are scored against the label at position i+1
            shifted_logits = logits[:, :-1].reshape(-1, tokenizer.vocab_size)
            shifted_labels = labels[:, 1:].reshape(-1)

            loss = self.criterion(shifted_logits, shifted_labels)

        return {'loss': loss, 'logits': logits}


model = Model()

# Count the parameters (printed in units of 10,000)
print(sum(i.numel() for i in model.parameters()) / 10000)

out = model(**data)

out['loss'], out['logits'].shape
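
Since the custom head copies the lm_head weights of distilroberta-base, pushing the same batch through AutoModelForMaskedLM should give logits of the same shape; a sketch of that cross-check:

from transformers import AutoModelForMaskedLM

reference = AutoModelForMaskedLM.from_pretrained('distilroberta-base')
with torch.no_grad():
    reference_logits = reference(input_ids=data['input_ids'],
                                 attention_mask=data['attention_mask']).logits

# same shape as out['logits']: [8, 9, 50265]
print(reference_logits.shape)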



Testing

# Test
def test():
    model.eval()

    loader_test = torch.utils.data.DataLoader(
        dataset=dataset['test'],
        batch_size=8,
        collate_fn=default_data_collator,
        shuffle=True,
        drop_last=True,
    )

    correct = 0
    total = 0
    for i, data in enumerate(loader_test):

        # Save the labels from the batch; they are needed later for the accuracy computation
        label = data['labels'][:, 4].clone()

        # Erase the labels from the batch so the model cannot cheat
        data['labels'] = None

        # Forward pass
        with torch.no_grad():
            out = model(**data)

        # [8, 9, 50265] -> [8, 9] -> [8]
        out = out['logits'].argmax(dim=2)[:, 4]

        correct += (label == out).sum().item()
        total += 8

        if i % 10 == 0:
            print(i)
            print(label)
            print(out)

        if i == 50:
            break

    print(correct / total)

    for i in range(8):
        print(tokenizer.decode(data['input_ids'][i]))
        print(tokenizer.decode(label[i]), tokenizer.decode(out[i]))


test()

0
tensor([   47, 14838,  5392,    28,    80,  4839,  3668,    29])
tensor([   47, 14633,   749,    28,    80,  4839,  3668,  2156])
10
tensor([ 101,  668,   16,   14,  352,  650, 3961,   16])
tensor([ 101,  773, 7897,   59, 2156, 7397, 3961,   16])
20
tensor([40485,    13,    29, 19303,    33,    16,   295,     9])
tensor([40485,    13,  4839, 16393,    33,  3391,   256,     9])
30
tensor([   53, 33469,  3315,  3723,     7, 24473, 40776,    41])
tensor([11248, 15923,  3315,  3723,     7, 24473, 40776,    41])
40
tensor([ 2435,     5,  2046,  2084, 25210,     9, 42661,     7])
tensor([ 2343,    42,  4265,  8003, 33709,  7021,  9021,     6])
50
tensor([  297, 22258,   998,    64,    10,  1499,    65,  2156])
tensor([  457, 22258,  6545,    64,    10, 10416,    65, 33647])
0.32598039215686275
<s>a strong first<mask>, slightly less</s>
 quarter  half
<s>( villene<mask> ) seems to</s>
uve uve
<s>going to the<mask> may be just</s>
 website  gym
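
A small helper for querying the model on an arbitrary sentence containing <mask> (a sketch; predict_mask is not part of the original code):

def predict_mask(text):
    enc = tokenizer(text, return_tensors='pt')
    with torch.no_grad():
        logits = model(input_ids=enc['input_ids'],
                       attention_mask=enc['attention_mask'])['logits']

    # locate <mask> and take the most likely token at that position
    pos = (enc['input_ids'][0] == tokenizer.mask_token_id).nonzero().item()
    return tokenizer.decode(logits[0, pos].argmax())

predict_mask('a strong first<mask>, slightly less')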

from transformers import AdamW
from transformers.optimization import get_scheduler


# Train
def train():
    optimizer = AdamW(model.parameters(), lr=2e-5)
    scheduler = get_scheduler(name='linear',
                              num_warmup_steps=0,
                              num_training_steps=len(loader),
                              optimizer=optimizer)

    model.train()
    for i, data in enumerate(loader):
        out = model(**data)
        loss = out['loss']

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

        optimizer.step()
        scheduler.step()

        optimizer.zero_grad()
        model.zero_grad()

        if i % 50 == 0:
            label = data['labels'][:, 4]
            out = out['logits'].argmax(dim=2)[:, 4]

            correct = (label == out).sum().item()
            accuracy = correct / 8

            lr = optimizer.state_dict()['param_groups'][0]['lr']

            print(i, loss.item(), accuracy, lr)

    torch.save(model, 'models/2.预测中间词.model')


train()

/root/anaconda3/envs/cpu/lib/python3.6/site-packages/transformers/optimization.py:309: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning
  FutureWarning,
0 18.949838638305664 0.0 1.9996385977593064e-05
50 4.755198001861572 0.625 1.9815684857246115e-05
100 5.0272216796875 0.25 1.963498373689917e-05
150 4.625316143035889 0.125 1.9454282616552225e-05
200 3.663780927658081 0.5 1.927358149620528e-05
250 2.5342917442321777 0.375 1.909288037585833e-05
300 4.986537933349609 0.375 1.8912179255511386e-05
350 3.403028964996338 0.625 1.873147813516444e-05
400 4.041268348693848 0.125 1.8550777014817495e-05
450 3.2715964317321777 0.5 1.8370075894470547e-05
500 2.6591811180114746 0.5 1.81893747741236e-05
550 4.937175750732422 0.25 1.8008673653776656e-05
600 4.845945835113525 0.25 1.7827972533429708e-05
650 1.8658218383789062 0.625 1.7647271413082763e-05
700 3.9473319053649902 0.25 1.7466570292735818e-05
750 2.065851926803589 0.625 1.728586917238887e-05
800 2.957096576690674 0.5 1.7105168052041924e-05
850 4.987250804901123 0.25 1.692446693169498e-05
900 3.5697021484375 0.5 1.674376581134803e-05
950 2.898092746734619 0.5 1.6563064691001085e-05
1000 4.39031457901001 0.375 1.638236357065414e-05
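
The logged learning rate matches the linear schedule: with no warmup, the lr after i+1 optimizer steps is 2e-5 * (5534 - (i+1)) / 5534, so the first logged value is 2e-5 * 5533 / 5534 ≈ 1.9996e-05.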

Prediction

model = torch.load('models/2.预测中间词.model')
test()
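
torch.save(model, ...) pickles the whole Model object; a more portable alternative (a sketch) saves only the weights:

# save and restore via state_dict instead of pickling the module
torch.save(model.state_dict(), 'models/2.预测中间词.state_dict')

model = Model()
model.load_state_dict(torch.load('models/2.预测中间词.state_dict'))
test()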
