系列文章目录
终章 1:Attention的结构
终章 2:带Attention的seq2seq的实现
终章 3:Attention的评价
终章 4:关于Attention的其他话题
终章 5:Attention的应用
文章目录
目录
前言
下面,我们使用上一节实现的AttentionSeq2seq类来挑战一个实际问题。 我们通过研究“日期格式转换”问题(本质上属于人为创造的问题,数据量有限),来确认带Attention的seq2seq 的效果。
一、日期格式转换问题
这里我们要处理的是日期格式转换问题。这个任务旨在将使用英语的国家和地区所使用的各种各样的日期格式转换为标准格式。例如,将人写的 “september 27, 1994”这样的日期数据转换为“1994-09-27”这样的标准格式,如下图所示。
这里采用日期格式转换问题的原因有两个。首先,该问题并不像看上去那么简单。因为输入的日期数据存在各种各样的版本,所以转换规则也相应 地复杂。如果尝试将这些转换规则全部写出来,那将非常费力。 其次,该问题的输入(问句)和输出(回答)存在明显的对应关系。具 体而言,存在年月日的对应关系。因此,我们可以确认Attention有没有正 确地关注各自的对应元素。 事先在dataset/date.txt中准备好了要处理的日期转换数据。如下图所示,这个文本文件包含50 000个日期转换用的学习数据。
为了对齐输入语句的长度,本书提供的日期数据集填充了空格,并将 “_”(下划线)设置为输入和输出的分隔符。另外,因为这个问题输出的字 符数是恒定的,所以无须使用分隔符来指示输出的结束。
二、带Attention的seq2seq的学习
下面,我们在日期转换用的数据集上进行AttentionSeq2seq的学习,学习用的代码如下所示
# coding: utf-8
import sys
sys.path.append('..')
import numpy as np
import matplotlib.pyplot as plt
from dataset import sequence
from common.optimizer import Adam
from common.trainer import Trainer
from common.util import eval_seq2seq
from attention_seq2seq import AttentionSeq2seq
from ch07.seq2seq import Seq2seq
from ch07.peeky_seq2seq import PeekySeq2seq
# 读入数据
(x_train, t_train), (x_test, t_test) = sequence.load_data('date.txt')
char_to_id, id_to_char = sequence.get_vocab()
# 反转输入语句
x_train, x_test = x_train[:, ::-1], x_test[:, ::-1]
# 设定超参数
vocab_size = len(char_to_id)
wordvec_size = 16
hidden_size = 256
batch_size = 128
max_epoch = 10
max_grad = 5.0
model = AttentionSeq2seq(vocab_size, wordvec_size, hidden_size)
# model = Seq2seq(vocab_size, wordvec_size, hidden_size)
# model = PeekySeq2seq(vocab_size, wordvec_size, hidden_size)
optimizer = Adam()
trainer = Trainer(model, optimizer)
acc_list = []
for epoch in range(max_epoch):
trainer.fit(x_train, t_train, max_epoch=1,
batch_size=batch_size, max_grad=max_grad)
correct_num = 0
for i in range(len(x_test)):
question, correct = x_test[[i]], t_test[[i]]
verbose = i < 10
correct_num += eval_seq2seq(model, question, correct,
id_to_char, verbose, is_reverse=True)
acc = float(correct_num) / len(x_test)
acc_list.append(acc)
print('val acc %.3f%%' % (acc * 100))
model.save_params()
# 绘制图形
x = np.arange(len(acc_list))
plt.plot(x, acc_list, marker='o')
plt.xlabel('epochs')
plt.ylabel('accuracy')
plt.ylim(-0.05, 1.05)
plt.show()
这里显示的代码和上一章的加法问题的学习用代码几乎一样。区别在 于,它读入日期数据作为学习数据,使用AttentionSeq2seq作为模型。另外, 这里还使用了反转输入语句的技巧(Reverse)。之后,在学习的同时,每个 epoch 使用测试数据计算正确率。为了查看结果,我们将前10个问题的问句和回答输出到终端。 现在我们运行一下上面的代码。随着学习的进行,部分结果如下图(以下第一轮的结果不好,但后面慢慢变好):
如上图所示,随着学习的深入,带Attention的seq2seq变聪明了(我电脑上耗时10分钟左右)。 实际上,没过多久,它就对大多数问题给出了正确答案(第二轮开始)。此时,测试数据的正确率(代码中的acc_list)
如上图所示,从第1个epoch开始,正确率迅速上升,到第2个epoch 时,几乎可以正确回答所有问题。这可以说是一个很好的结果。我们将这个结果与上一章的模型比较一下,如下图所示
从上图的结果可知,简单的seq2seq(图中的baseline)完全没法用。 即使经过了10个epoch,大多数问题还是不能回答正确。而使用了“偷窥” 技术的Peeky给出了良好的结果,从第3个epoch开始,模型的正确率开始上升,在第4个epoch时,正确率达到了100%。但是,就学习速度而言, Attention 稍微有些优势。 在这次的实验中,就最终精度来看,Attention和Peeky取得了差不多的结果。但是,随着时序数据变长、变复杂,除了学习速度之外, Attention 在精度上也会变得更有优势。
三、Attention的可视化
接下来,我们对Attention进行可视化。在进行时序转换时,实际观察 Attention 在注意哪个元素。因为在Attention层中,各个时刻的Attention 权重均保存到了成员变量中,所以我们可以轻松地进行可视化。代码如下:
# coding: utf-8
import sys
sys.path.append('..')
import numpy as np
from dataset import sequence
import matplotlib.pyplot as plt
from attention_seq2seq import AttentionSeq2seq
(x_train, t_train), (x_test, t_test) = \
sequence.load_data('date.txt')
char_to_id, id_to_char = sequence.get_vocab()
# Reverse input
x_train, x_test = x_train[:, ::-1], x_test[:, ::-1]
vocab_size = len(char_to_id)
wordvec_size = 16
hidden_size = 256
model = AttentionSeq2seq(vocab_size, wordvec_size, hidden_size)
model.load_params()
_idx = 0
def visualize(attention_map, row_labels, column_labels):
fig, ax = plt.subplots()
ax.pcolor(attention_map, cmap=plt.cm.Greys_r, vmin=0.0, vmax=1.0)
ax.patch.set_facecolor('black')
ax.set_yticks(np.arange(attention_map.shape[0])+0.5, minor=False)
ax.set_xticks(np.arange(attention_map.shape[1])+0.5, minor=False)
ax.invert_yaxis()
ax.set_xticklabels(row_labels, minor=False)
ax.set_yticklabels(column_labels, minor=False)
global _idx
_idx += 1
plt.show()
np.random.seed(1984)
for _ in range(5):
idx = [np.random.randint(0, len(x_test))]
x = x_test[idx]
t = t_test[idx]
model.forward(x, t)
d = model.decoder.attention.attention_weights
d = np.array(d)
attention_map = d.reshape(d.shape[0], d.shape[2])
# reverse for print
attention_map = attention_map[:,::-1]
x = x[:,::-1]
row_labels = [id_to_char[i] for i in x[0]]
column_labels = [id_to_char[i] for i in t[0]]
column_labels = column_labels[1:]
visualize(attention_map, row_labels, column_labels)
在我们的实现中,Time Attention层中的成员变量attention_weights 保 存了各个时刻的Attention权重,据此可以将输入语句和输出语句的各个单词 的对应关系绘制成一张二维地图。这里,我们针对学习好的AttentionSeq2seq, 对进行日期格式转换时的Attention权重进行可视化。将结果显示在下图上:
上图是seq2seq 进行时序转换时的Attention权重的可视化结果。例如,我们可以看到,当seq2seq输出第1个“1”时,注意力集中在输入语句的“1”上。这里需要特别注意年月日的对应关系。仔细观察图中的结果,纵轴(输出)的“1983”和“26”恰好对应于横轴(输入)的“1983”和“26”。 另外,输入语句的“AUGUST”对应于表示月份的“08”,这一点也很令人惊讶。这表明seq2seq从数据中学习到了“August”和“8月”的对应关系。下图中给出了其他一些例子,从中也可以很清楚地看到年月日的对应关系。
像这样,使用Attention,seq2seq能像我们人一样将注意力集中在必 要的信息上。换言之,借助Attention,我们理解了模型是如何工作的。
总结
以上就是关于Attention的评价的内容。通过这里的实验,我们体验了 Attention的奇妙效果。至此,Attention的核心话题就要告一段落了,但是关于Attention的其他内容还有不少。下一次我们继续围绕Attention, 给出它的几个高级技巧。