PyTorch深度学习实战(43)——手写文本识别
0. 前言
手写文本识别,也称为手写文本的光学字符识别 (Optical Character Recognition
, OCR
),是计算机视觉和自然语言处理中的一项具有挑战性的任务。与印刷文本不同,手写文本在风格、大小和质量方面变化巨大,这使得识别和转录变得更加困难。手写文本识别的目标是准确地识别和转录手写文本,以便进行进一步的分析、存储或处理。我们已经学习了如何根据输入图像生成描述性文本单词序列,在本节中,我们将学习如何根据手写文字图像输入生成字符序列,为了提高手写图像的转录性能,将引入 CTC
损失函数。
1. 手写文本识别
1.1 基本概念
手写文本识别与图像字幕生成不同,图像字幕生成模型中所用图像的内容与输出单词之间没有直接的相关性,而手写图像中的字符序列与输出序列之间存在直接相关性。因此,图像字幕生成模型架构并不适用于手写文本识别模型,需要设计不同的架构。
假设一张图像被分成 20
个部分(假设一个图像中每个单词最多包含 20
个字符),其中每个部分(在循环神经网络中每个部分可以作为一个时间步的输入)对应一个字符。在手写文本图像中,有些笔迹可能会确保每个字符完全对应每个部分,而有些笔迹可能较为混乱,使得每个部分包含不同数量的字符,或者可能导致两个字符之间的间距太大以至于无法将一个单词适配到 20
个部分中。为了解决这些问题,引入了 Connectionist temporal classification
(CTC
) 损失函数。
1.2 输入和输出格式
假设我们需要识别包含文本 ab
的图像。示例图像如下,字符 a
和 b
之间的具有不同长度的空格,但输出标签均为 ab
:
我们可以将这些图像样本分割为多个时间步,如下所示,其中每个方框代表一个时间步,因此可以看到共有六个时间步:
预测每个时间步的输出字符,其中每个时间步的 softmax
输出是整个词汇表中每个字母的类别概率,则第一张关于 ab
图片的每个时间步的输出如下:
在上图中的 -
表示空白。此外,如果图像的特征通过双向长短时记忆网络 (Long Short-Term Memory
, LSTM
) 传递,第 3
和第 4
时间步的输出可能均为 b
,因为在执行双向 LSTM
时,下一个时间步中的信息也会影响上一个时间步的输出。在最后一步中,压缩所有在连续时间步中具有相同值的 softmax
输出,因此此样本最终输出为:-a-b-
。
如果存在连续的相同字符预测,则压缩重复字符的输出,最终输出如下所示:
-a-b-
而如果输出为 abb
时,则在压缩后的最终输出结果中需要在两个 b
字符之间添加一个分隔符:
-a-b-b-
1.3 CTC 损失值
如果要计算 CTC
损失值,我们考虑下图中的情形,图中的圆圈中提供了在给定时间步内不同字符类别的概率,可以看到,在从 t0
到 t5
的每个时间步内概率之和均为 1
:
为了简单起见,我们考虑以下情况:图片标签为 a
而不是 ab
,且输出只有 3
个时间步而不是 6
个时间步,输出结果如下所示:
下表列出了在每个时间步中的经过 softmax
激活函数后的输出概率,我们都可以得到输出标签 a
:
每个时间步的输出 | 时间步1中的字符概率 | 时间步2中的字符概率 | 时间步3中的字符概率 | 组合概率 | 最终概率 |
---|---|---|---|---|---|
–a | 0.8 | 0.1 | 0.1 | 0.8x0.1x0.1 | 0.008 |
-aa | 0.8 | 0.9 | 0.1 | 0.8 x 0.9 x 0.1 | 0.072 |
aaa | 0.2 | 0.9 | 0.1 | 0.2 x 0.9 x 0.1 | 0.018 |
-a- | 0.8 | 0.9 | 0.8 | 0.8 x 0.9 x 0.8 | 0.576 |
a-a | 0.8 | 0.9 | 0.1 | 0.8 x 0.9 x 0.1 | 0.072 |
a– | 0.2 | 0.9 | 0.8 | 0.2 x 0.1 x 0.8 | 0.016 |
aa- | 0.2 | 0.1 | 0.8 | 0.2 x 0.9 x 0.8 | 0.144 |
总概率 | - | - | - | - | 0.906 |
从前面的结果中,我们可以获得标签 a
的总概率为 0.906
,CTC
损失是总概率的负对数,即 − l o g ( 0.906 ) = 0.04 -log(0.906)= 0.04 −log(0.906)=0.04。由于在每个时间戳中具有最高概率的字符的组合预测了标签 a
,因此 CTC
损失接近于零。
2. 模型与数据集分析
2.1 数据集分析
本文使用 IAM
手写数据集训练手写文字文本模型,IAM
手写数据集包含手写英文文本,可用于训练和测试手写文本识别模型。该数据集中包含不同类型的手写文本形式,这些文本是 300dpi
分辨率的扫描件,并保存为 256
级灰度 PNG
图像,下图是一些数据集中的样本图片:
数据集中的字符是使用自动分割算法从扫描件中提取,并经过人工验证。同时,数据集 xml.tgz
中包含 XML
文件,每个 XML
文件都记录了一系列手写文本图片的相关信息,包括文件名、图片中的字符等。
该数据集可从以下链接下载:https://pan.baidu.com/s/1sr8ZMCxoNKymZRrRMX75wg,提取码: 6znj
。
2.2 模型分析
在实现手写文本识别模型前,我们首先介绍用于转录手写文本图像的模型策略流程:
- 导入图像数据集及其对应的文本标签
- 为每个字符分配一个索引
- 通过卷积神经网络获取输入图像对应的特征图
- 通过循环神经网络 (
Recurrent Neural Network
,RNN
) 传递特征图 - 获取每个时间步的概率
- 利用
CTC
损失函数压缩输出并获取文本标签和相应的损失 - 通过最小化
CTC
损失函数优化网络权重
3. 实现手写文本识别模型
接下来,我们使用 PyTorch
实现上一小节介绍的手写文本识别模型。
(1) 下载并解压文本图片和 XML
标注数据集,其中包含了手写文本的图像及其相应的标签数据。图像示例如下:
(2) 导入所需库:
from torchsummary import summary
import editdistance
import torch
from glob import glob
import numpy as np
from torch.utils.data import DataLoader, Dataset
import random
import cv2
from torch import nn, optim
from matplotlib import pyplot as plt
(3) 指定图像的位置并获取图像对应的文本标签:
def stem(filename):
filename = filename.split('/')[-1]
return filename.split('.')[0]
device = 'cuda' if torch.cuda.is_available() else 'cpu'
fname2label = lambda fname: stem(fname).split('@')[0]
images = glob('synthetic-data/synthetic-data/*.png')
在以上代码中,创建 fname2label
函数,图像的文本标签在文件名中的 @
符号之前。文件名示例如下:
(4) 定义字符的词汇表 (vocab
)、批大小 (B
)、RNN
的时间步长 (T
)、词汇表的长度 (V
)、图片的高度 (H
) 和宽度 (W
):
vocab = 'QWERTYUIOPASDFGHJKLZXCVBNMqwertyuiopasdfghjklzxcvbnm'
B,T,V = 64, 32, len(vocab)
H,W = 32, 128
(5) 定义数据集类 OCRDataset
。
定义 __init__
方法,通过循环遍历 vocab
指定字符到字符 ID
的映射 (charList
) 及其反向映射 (invCharList
),以及时间步数 (timesteps
) 和要获取的图像文件路径 (item
) 。使用 charList
和 invCharList
而非使用 torchtext
的构建词汇表,因为词汇表更易于处理(包含较少数量的不同字符):
class OCRDataset(Dataset):
def __init__(self, items, vocab=vocab, preprocess_shape=(H,W), timesteps=T):
super().__init__()
self.items = items
self.charList = {ix+1:ch for ix,ch in enumerate(vocab)}
self.charList.update({0: '`'})
self.invCharList = {v:k for k,v in self.charList.items()}
self.ts = timesteps
定义 __len__
和 __getitem__
方法:
def __len__(self):
return len(self.items)
def sample(self):
return self[random.randint(0, len(self))]
def __getitem__(self, ix):
item = self.items[ix]
image = cv2.imread(item, 0)
label = fname2label(item)
return image, label
在 __getitem__
方法中,读取图像并使用 fname2label
函数创建标签。此外,定义采样方法 sample
,用于从数据集中随机采样图像。
定义 collate_fn
方法,接受批数据图像并将它们及其标签添加到不同列表中。此外,它将与图像相对应的真实值字符转换为向量格式(将每个字符转换为其对应的 ID
),最后存储每个图像的标签长度和输入长度(时间步数)。在计算损失值时,CTC
损失函数会利用标签长度和输入长度:
def collate_fn(self, batch):
images, labels, label_lengths, label_vectors, input_lengths = [], [], [], [], []
for image, label in batch:
images.append(torch.Tensor(self.preprocess(image))[None,None])
label_lengths.append(len(label))
labels.append(label)
label_vectors.append(self.str2vec(label))
input_lengths.append(self.ts)
将上述列表转换为 Torch
张量对象并返回 images
、labels
、label_lengths
、label_vectors
和 input_lengths
:
images = torch.cat(images).float().to(device)
label_lengths = torch.Tensor(label_lengths).long().to(device)
label_vectors = torch.Tensor(label_vectors).long().to(device)
input_lengths = torch.Tensor(input_lengths).long().to(device)
return images, label_vectors, label_lengths, input_lengths, labels
定义 str2vec
函数,将字符 ID
的输入转换为字符串:
def str2vec(self, string, pad=True):
string = ''.join([s for s in string if s in self.invCharList])
val = list(map(lambda x: self.invCharList[x], string))
if pad:
while len(val) < self.ts:
val.append(0)
return val
在 str2vec
函数中,从字符 ID
串中获取字符,如果标签的长度 (len(val)
) 小于时间步 (self.ts
),使用零填充索引将向量附加到标签向量中。
定义预处理函数,该函数将图像 (img
) 和形状 shape
作为输入,将其处理为 32 x 128
的统一形状,除了调整图像大小之外,还需要进行额外的预处理,因为要在保持纵横比的前提下需要调整图像大小。
定义 preprocess
函数以及图像的目标形状,图像初始化为空白图像 target
:
def preprocess(self, img, shape=(32,128)):
target = np.ones(shape)*255
获取图像的形状和预期形状:
try:
H, W = shape
h, w = img.shape
调整图像大小以保持纵横比:
fx = H/h
fy = W/w
f = min(fx, fy)
_h = int(h*f)
_w = int(w*f)
调整图像大小并将其存储在 target
中:
_img = cv2.resize(img, (_w,_h))
target[:_h,:_w] = _img
返回标准化图像,首先将图像转换为黑色背景,然后将像素值缩放到 0
到 1
之间:
except:
pass
return (255-target)/255
定义 decoder_chars
函数将预测解码为单词:
def decoder_chars(self, pred):
decoded = ""
last = ""
pred = pred.cpu().detach().numpy()
for i in range(len(pred)):
k = np.argmax(pred[i])
if k > 0 and self.charList[k] != last:
last = self.charList[k]
decoded = decoded + last
elif k > 0 and self.charList[k] == last:
continue
else:
last = ""
return decoded.replace(" "," ")
在以上代码中,我们一次一个时间步循环遍历预测 (pred
),获取置信度最高的字符 (k
),并将其与前一个时间步中置信度最高的字符 (last
) 进行比较,如果上一个时间步中置信度最高的字符与当前时间步中置信度最高的字符不同,则将其附加到已解码的字符中。
定义计算字符准确率和单词准确率的方法:
def wer(self, preds, labels):
c = 0
for p, l in zip(preds, labels):
c += p.lower().strip() != l.lower().strip()
return round(c/len(preds), 4)
def cer(self, preds, labels):
c, d = [], []
for p, l in zip(preds, labels):
c.append(editdistance.eval(p, l) / len(l))
return round(np.mean(c), 4)
定义 evaluate
方法,在一组图像上评估模型并返回单词错误率和字符错误率:
def evaluate(self, model, ims, labels, lower=False):
model.eval()
preds = model(ims).permute(1,0,2) # B, T, V+1
preds = [self.decoder_chars(pred) for pred in preds]
return {'char-error-rate': self.cer(preds, labels),
'word-error-rate': self.wer(preds, labels),
'char-accuracy' : 1 - self.cer(preds, labels),
'word-accuracy' : 1 - self.wer(preds, labels)}
在以上代码中,对输入图像的通道进行排列,以便将数据预处理为模型所期望的输入格式,使用 decoder_chars
函数对预测进行解码,然后返回字符错误率、单词错误率及其相应的准确率。
(6) 定义训练和验证数据集以及数据加载器:
from sklearn.model_selection import train_test_split
trn_items, val_items = train_test_split(glob('synthetic-data/synthetic-data/*.png'), test_size=0.2, random_state=22)
trn_ds = OCRDataset(trn_items)
val_ds = OCRDataset(val_items)
trn_dl = DataLoader(trn_ds, batch_size=B, collate_fn=trn_ds.collate_fn, drop_last=True, shuffle=True)
val_dl = DataLoader(val_ds, batch_size=B, collate_fn=val_ds.collate_fn, drop_last=True)
(7) 构建网络模型。
构建卷积神经网络 (Convolutional Neural Networks
, CNN
) 的基本块:
class BasicBlock(nn.Module):
def __init__(self, ni, no, ks=3, st=1, padding=1, pool=2, drop=0.2):
super().__init__()
self.ks = ks
self.block = nn.Sequential(
nn.Conv2d(ni, no, kernel_size=ks, stride=st, padding=padding),
nn.BatchNorm2d(no, momentum=0.3),
nn.ReLU(inplace=True),
nn.MaxPool2d(pool),
nn.Dropout2d(drop)
)
def forward(self, x):
return self.block(x)
class Reshape(nn.Module):
def __init__(self, *shape):
super().__init__()
self.shape = shape
def forward(self, x):
return x.view(*self.shape)
class Permute(nn.Module):
def __init__(self, *order):
super().__init__()
self.order = order
def forward(self, x):
return x.permute(*self.order)
构建神经网络类 OCR
,其中 CNN
块和 RNN
块分别在 self.model
和 self.rnn
的 __init__
方法中定义。接下来,定义 self.classification
层,获取 RNN
的输出并在通过全连接层处理 RNN
输出后将其传递给 softmax
激活:
class Ocr(nn.Module):
def __init__(self, vocab):
super().__init__()
self.model = nn.Sequential(
BasicBlock( 1, 128),
BasicBlock(128, 128),
BasicBlock(128, 256, pool=(4,2)),
Reshape(-1, 256, 32),
Permute(2, 0, 1) # T, B, D
)
self.rnn = nn.Sequential(
nn.LSTM(256, 256, num_layers=2, dropout=0.2, bidirectional=True),
)
self.classification = nn.Sequential(
nn.Linear(512, vocab+1),
nn.LogSoftmax(-1),
)
定义前向计算方法 forward
:
def forward(self, x):
x = self.model(x)
x, lstm_states = self.rnn(x)
y = self.classification(x)
return y
在以上代码中,首先获取 CNN
输出,然后将其传递给 RNN
,以获取 lstm_states
和 RNN
输出 x
,最后通过分类层 (self.classification
) 输出并返回结果。
定义 CTC
损失函数:
def ctc(log_probs, target, input_lengths, target_lengths, blank=0):
loss = nn.CTCLoss(blank=blank, zero_infinity=True)
ctc_loss = loss(log_probs, target, input_lengths, target_lengths)
return ctc_loss
在以上代码中,利用 nn.CTCLoss
方法最小化 ctc_loss
,该方法将置信度矩阵 log_probs
(每个时间步的预测)、目标值 target
(真实标签)、输入长度 input_lengths
和目标长度 target_lengths
作为输入,返回 ctc_loss
值。
由于词汇表中有 53
个字符 (52
个字母加 1
个分隔符),因此输出中每张图像都有 53
个相关联的概率值输出。
(8) 定义函数在批数据上训练模型:
def train_batch(data, model, optimizer, criterion):
model.train()
imgs, targets, label_lens, input_lens, labels = data
optimizer.zero_grad()
preds = model(imgs)
loss = criterion(preds, targets, input_lens, label_lens)
loss.backward()
optimizer.step()
results = trn_ds.evaluate(model, imgs.to(device), labels)
return loss, results
(9) 定义函数在批数据上测试模型:
@torch.no_grad()
def validate_batch(data, model, criterion):
model.eval()
imgs, targets, label_lens, input_lens, labels = data
preds = model(imgs)
loss = criterion(preds, targets, input_lens, label_lens)
return loss, val_ds.evaluate(model, imgs.to(device), labels)
(10) 定义模型对象、优化器、损失函数和 epoch
数:
model = Ocr(len(vocab)).to(device)
criterion = ctc
optimizer = optim.AdamW(model.parameters(), lr=3e-3)
n_epochs = 60
(11) 训练模型:
trn_loss_epoch = []
val_loss_epoch = []
trn_char_acc_epoch = []
trn_word_acc_epoch = []
val_char_acc_epoch = []
val_word_acc_epoch = []
for ep in range(n_epochs):
N = len(trn_dl)
trn_loss_items = []
val_loss_items = []
trn_char_acc_items = []
trn_word_acc_items = []
val_char_acc_items = []
val_word_acc_items = []
for ix, data in enumerate(trn_dl):
pos = ep + (ix+1)/N
loss, results = train_batch(data, model, optimizer, criterion)
ca, wa = results['char-accuracy'], results['word-accuracy']
trn_loss_items.append(loss.item())
trn_char_acc_items.append(ca)
trn_word_acc_items.append(wa)
trn_loss_epoch.append(np.average(trn_loss_items))
trn_char_acc_epoch.append(np.average(trn_char_acc_items))
trn_word_acc_epoch.append(np.average(trn_word_acc_items))
N = len(val_dl)
for ix, data in enumerate(val_dl):
pos = ep + (ix+1)/N
loss, results = validate_batch(data, model, criterion)
ca, wa = results['char-accuracy'], results['word-accuracy']
val_loss_items.append(loss.item())
val_char_acc_items.append(ca)
val_word_acc_items.append(wa)
val_loss_epoch.append(np.average(val_loss_items))
val_char_acc_epoch.append(np.average(val_char_acc_items))
val_word_acc_epoch.append(np.average(val_word_acc_items))
print()
for jx in range(5):
img, label = val_ds.sample()
_img = torch.Tensor(val_ds.preprocess(img)[None,None]).to(device)
pred = model(_img)[:,0,:]
pred = trn_ds.decoder_chars(pred)
print(f'Pred: `{pred}` :: Truth: `{label}`')
print()
epochs = np.arange(n_epochs)+1
plt.plot(epochs, trn_word_acc_epoch, 'bo', label='Training accuracy')
plt.plot(epochs, val_word_acc_epoch, 'r-', label='Validation accuracy')
plt.title('Training and validation word accuracy')
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.legend()
plt.grid('off')
plt.show()
在上图中,我们可以看到该模型在验证数据集上的单词准确率约为 80%
。模型训练结束时的预测结果如下:
Pred: `sten` :: Truth: `step`
Pred: `falay` :: Truth: `today`
Pred: `admtroton` :: Truth: `administration`
Pred: `hrothoe` :: Truth: `brother`
Pred: `meonee` :: Truth: `response`
...
Pred: `speak` :: Truth: `speak`
Pred: `yard` :: Truth: `yard`
Pred: `executive` :: Truth: `executive`
Pred: `sit` :: Truth: `sit`
Pred: `since` :: Truth: `since`
小结
手写文本识别在许多领域有着广泛的应用,比如文档数字化、历史文献研究、表格处理等,可以帮助提高文档的可访问性、搜索性和可编辑性,节省人工时间和精力。在本节中,我们学习了如何结合使用 CNN
和 RNN
执行手写文本识别任务。
系列链接
PyTorch深度学习实战(1)——神经网络与模型训练过程详解
PyTorch深度学习实战(2)——PyTorch基础
PyTorch深度学习实战(3)——使用PyTorch构建神经网络
PyTorch深度学习实战(4)——常用激活函数和损失函数详解
PyTorch深度学习实战(5)——计算机视觉基础
PyTorch深度学习实战(6)——神经网络性能优化技术
PyTorch深度学习实战(7)——批大小对神经网络训练的影响
PyTorch深度学习实战(8)——批归一化
PyTorch深度学习实战(9)——学习率优化
PyTorch深度学习实战(10)——过拟合及其解决方法
PyTorch深度学习实战(11)——卷积神经网络
PyTorch深度学习实战(12)——数据增强
PyTorch深度学习实战(13)——可视化神经网络中间层输出
PyTorch深度学习实战(14)——类激活图
PyTorch深度学习实战(15)——迁移学习
PyTorch深度学习实战(16)——面部关键点检测
PyTorch深度学习实战(17)——多任务学习
PyTorch深度学习实战(18)——目标检测基础
PyTorch深度学习实战(19)——从零开始实现R-CNN目标检测
PyTorch深度学习实战(20)——从零开始实现Fast R-CNN目标检测
PyTorch深度学习实战(21)——从零开始实现Faster R-CNN目标检测
PyTorch深度学习实战(22)——从零开始实现YOLO目标检测
PyTorch深度学习实战(23)——从零开始实现SSD目标检测
PyTorch深度学习实战(24)——使用U-Net架构进行图像分割
PyTorch深度学习实战(25)——从零开始实现Mask R-CNN实例分割
PyTorch深度学习实战(26)——多对象实例分割
PyTorch深度学习实战(27)——自编码器(Autoencoder)
PyTorch深度学习实战(28)——卷积自编码器(Convolutional Autoencoder)
PyTorch深度学习实战(29)——变分自编码器(Variational Autoencoder, VAE)
PyTorch深度学习实战(30)——对抗攻击(Adversarial Attack)
PyTorch深度学习实战(31)——神经风格迁移
PyTorch深度学习实战(32)——Deepfakes
PyTorch深度学习实战(33)——生成对抗网络(Generative Adversarial Network, GAN)
PyTorch深度学习实战(34)——DCGAN详解与实现
PyTorch深度学习实战(35)——条件生成对抗网络(Conditional Generative Adversarial Network, CGAN)
PyTorch深度学习实战(36)——Pix2Pix详解与实现
PyTorch深度学习实战(37)——CycleGAN详解与实现
PyTorch深度学习实战(38)——StyleGAN详解与实现
PyTorch深度学习实战(39)——小样本学习(Few-shot Learning)
PyTorch深度学习实战(40)——零样本学习(Zero-Shot Learning)
PyTorch深度学习实战(41)——循环神经网络与长短期记忆网络
PyTorch深度学习实战(42)——图像字幕生成