PDF文件转换成其他格式常常是个大难题,大量的信息被锁在PDF里,AI应用无法直接访问。如果能把PDF文件或其对应的图像转换成结构化或半结构化的机器可读格式,那就能大大缓解这个问题,同时也能显著增强人工智能应用的知识库。
嘿,各位AI探险家们!今天我们将踏上了一段奇妙的PDF解析之旅,探索了那些不用OCR(光学字符识别)也能搞定PDF的神奇小模型。就像哈利·波特不用魔杖也能施法一样,这些小模型用神经网络直接“读懂”PDF,省去了繁琐的OCR步骤,简直是AI界的“无杖魔法”!
RAG(Retrieval-Augmented Generation)基建之PDF解析的“魔法”与“陷阱”
文章目录
概述
之前介绍的基于流水线的PDF解析方法主要使用OCR引擎进行文本识别。然而,这种方法计算成本高,对语言和文档类型的灵活性较差,且OCR错误可能影响后续任务。
因此,应该开发OCR-Free方法,如图1所示。这些方法不显式使用OCR来识别文本,而是使用神经网络隐式完成任务。本质上,这些方法采用端到端的方式,直接输出PDF解析结果。
OCR-Free vs. 流水线:谁更香?
从结构上看,OCR-Free方法比基于流水线的方法更简单。OCR-Free方法主要需要注意的方面是模型结构的设计和训练数据的构建。OCR-Free方法虽然一步到位,避免了中间步骤的“损耗”,但它的训练和推理速度有点慢,像是一辆豪华跑车,虽然性能强大,但油耗高。而基于流水线的方法则像是一辆经济型小车,虽然步骤多,但每个模块都很轻量,适合大规模部署。
接下来,我们将介绍几种具有代表性的OCR-Free小型模型PDF解析框架:
-
- Donut:PDF解析界的“甜甜圈”
Donut这个小家伙,别看它名字甜,干起活来可是一点都不含糊。它用Swin Transformer当“眼睛”,BART当“嘴巴”,直接把PDF图像“吃”进去,吐出一串JSON格式的“甜点”。不用OCR,全靠神经网络,简直是PDF解析界的“甜品大师”!
- Donut:PDF解析界的“甜甜圈”
-
- Nougat:PDF解析界的“牛轧糖”
Nougat,名字听起来就很有嚼劲,它的绝活是把PDF图像变成Markdown。它特别擅长处理复杂的公式和表格,简直是PDF解析界的“糖果工匠”。不过,它的生成速度有点慢,像牛轧糖一样,嚼起来需要点耐心。
- Nougat:PDF解析界的“牛轧糖”
-
- ** Pix2Struct:PDF解析界的“像素魔法师”**
Pix2Struct是个视觉语言理解的高手,它的任务是从屏蔽的网页截图中预测HTML解析。它不仅能处理PDF,还能搞定网页截图,简直是多才多艺的“像素魔法师”。不过,它的训练数据来自网页,可能会带来一些“有害内容”,使用时得小心点。
- ** Pix2Struct:PDF解析界的“像素魔法师”**
详细介绍
Donut
如图2所示,Donut是一个端到端模型,旨在全面理解文档图像。其架构简单,由基于Transformer的视觉编码器和文本解码器模块组成。
Donut不依赖任何与OCR相关的模块,而是使用视觉编码器从文档图像中提取特征,并直接使用文本解码器生成token序列。输出序列可以转换为JSON等结构化格式。
代码如下:
class DonutModel(PreTrainedModel):
r"""
Donut: 一个端到端的OCR-Free文档理解Transformer。
编码器将输入的文档图像映射为一组嵌入,
解码器预测所需的token序列,可以将其转换为结构化格式,
给定提示和编码器输出的嵌入
"""
config_class = DonutConfig
base_model_prefix = "donut"
def __init__(self, config: DonutConfig):
super().__init__(config)
self.config = config
self.encoder = SwinEncoder(
input_size=self.config.input_size,
align_long_axis=self.config.align_long_axis,
window_size=self.config.window_size,
encoder_layer=self.config.encoder_layer,
name_or_path=self.config.name_or_path,
)
self.decoder = BARTDecoder(
max_position_embeddings=self.config.max_position_embeddings,
decoder_layer=self.config.decoder_layer,
name_or_path=self.config.name_or_path,
)
def forward(self, image_tensors: torch.Tensor, decoder_input_ids: torch.Tensor, decoder_labels: torch.Tensor):
"""
给定输入图像和所需的token序列计算损失,
模型将以教师强制的方式进行训练
参数:
image_tensors: (batch_size, num_channels, height, width)
decoder_input_ids: (batch_size, sequence_length, embedding_dim)
decode_labels: (batch_size, sequence_length)
"""
encoder_outputs = self.encoder(image_tensors)
decoder_outputs = self.decoder(
input_ids=decoder_input_ids,
encoder_hidden_states=encoder_outputs,
labels=decoder_labels,
)
return decoder_outputs
...
...
编码器
Donut使用Swin-Transformer作为图像编码器,因为它在初步的文档解析研究中表现出色。该图像编码器将输入的文档图像转换为一组高维嵌入。这些嵌入将作为文本解码器的输入。
对应代码如下:
class SwinEncoder(nn.Module):
r"""
基于SwinTransformer的Donut编码器
使用预训练的SwinTransformer设置初始权重和配置,
然后修改详细配置作为Donut编码器
参数:
input_size: 输入图像大小(宽度,高度)
align_long_axis: 如果高度大于宽度,是否旋转图像
window_size: SwinTransformer的窗口大小(=patch大小)
encoder_layer: SwinTransformer编码器的层数
name_or_path: 预训练模型名称,要么在huggingface.co.注册,要么保存在本地。
否则,将设置为`swin_base_patch4_window12_384`(使用`timm`)。
"""
def __init__(
self,
input_size: List[int],
align_long_axis: bool,
window_size: int,
encoder_layer: List[int],
name_or_path: Union[str, bytes, os.PathLike] = None,
):
super().__init__()
self.input_size = input_size
self.align_long_axis = align_long_axis
self.window_size = window_size
self.encoder_layer = encoder_layer
self.to_tensor = transforms.Compose(
[
transforms.ToTensor(),
transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
]
)
self.model = SwinTransformer(
img_size=self.input_size,
depths=self.encoder_layer,
window_size=self.window_size,
patch_size=4,
embed_dim=128,
num_heads=[4, 8, 16, 32],
num_classes=0,
)
self.model.norm = None
# 使用swin初始化权重
if not name_or_path:
swin_state_dict = timm.create_model("swin_base_patch4_window12_384", pretrained=True).state_dict()
new_swin_state_dict = self.model.state_dict()
for x in new_swin_state_dict:
if x.endswith("relative_position_index") or x.endswith