RAG基建之PDF解析的“无OCR”魔法之旅

发布于:2025-03-30 ⋅ 阅读:(27) ⋅ 点赞:(0)

PDF文件转换成其他格式常常是个大难题,大量的信息被锁在PDF里,AI应用无法直接访问。如果能把PDF文件或其对应的图像转换成结构化或半结构化的机器可读格式,那就能大大缓解这个问题,同时也能显著增强人工智能应用的知识库。

嘿,各位AI探险家们!今天我们将踏上了一段奇妙的PDF解析之旅,探索了那些不用OCR(光学字符识别)也能搞定PDF的神奇小模型。就像哈利·波特不用魔杖也能施法一样,这些小模型用神经网络直接“读懂”PDF,省去了繁琐的OCR步骤,简直是AI界的“无杖魔法”!

RAG(Retrieval-Augmented Generation)基建之PDF解析的“魔法”与“陷阱”

概述

之前介绍的基于流水线的PDF解析方法主要使用OCR引擎进行文本识别。然而,这种方法计算成本高,对语言和文档类型的灵活性较差,且OCR错误可能影响后续任务。

因此,应该开发OCR-Free方法,如图1所示。这些方法不显式使用OCR来识别文本,而是使用神经网络隐式完成任务。本质上,这些方法采用端到端的方式,直接输出PDF解析结果。在这里插入图片描述
OCR-Free vs. 流水线:谁更香?
从结构上看,OCR-Free方法比基于流水线的方法更简单。OCR-Free方法主要需要注意的方面是模型结构的设计和训练数据的构建。OCR-Free方法虽然一步到位,避免了中间步骤的“损耗”,但它的训练和推理速度有点慢,像是一辆豪华跑车,虽然性能强大,但油耗高。而基于流水线的方法则像是一辆经济型小车,虽然步骤多,但每个模块都很轻量,适合大规模部署。

接下来,我们将介绍几种具有代表性的OCR-Free小型模型PDF解析框架:

    1. Donut:PDF解析界的“甜甜圈”
      Donut这个小家伙,别看它名字甜,干起活来可是一点都不含糊。它用Swin Transformer当“眼睛”,BART当“嘴巴”,直接把PDF图像“吃”进去,吐出一串JSON格式的“甜点”。不用OCR,全靠神经网络,简直是PDF解析界的“甜品大师”!
    1. Nougat:PDF解析界的“牛轧糖”
      Nougat,名字听起来就很有嚼劲,它的绝活是把PDF图像变成Markdown。它特别擅长处理复杂的公式和表格,简直是PDF解析界的“糖果工匠”。不过,它的生成速度有点慢,像牛轧糖一样,嚼起来需要点耐心。
    1. ** Pix2Struct:PDF解析界的“像素魔法师”**
      Pix2Struct是个视觉语言理解的高手,它的任务是从屏蔽的网页截图中预测HTML解析。它不仅能处理PDF,还能搞定网页截图,简直是多才多艺的“像素魔法师”。不过,它的训练数据来自网页,可能会带来一些“有害内容”,使用时得小心点。

详细介绍

Donut

如图2所示,Donut是一个端到端模型,旨在全面理解文档图像。其架构简单,由基于Transformer的视觉编码器和文本解码器模块组成。在这里插入图片描述

Donut不依赖任何与OCR相关的模块,而是使用视觉编码器从文档图像中提取特征,并直接使用文本解码器生成token序列。输出序列可以转换为JSON等结构化格式。

代码如下:

class DonutModel(PreTrainedModel):
    r"""
    Donut: 一个端到端的OCR-Free文档理解Transformer。
    编码器将输入的文档图像映射为一组嵌入,
    解码器预测所需的token序列,可以将其转换为结构化格式,
    给定提示和编码器输出的嵌入
    """
    config_class = DonutConfig
    base_model_prefix = "donut"

    def __init__(self, config: DonutConfig):
        super().__init__(config)
        self.config = config
        self.encoder = SwinEncoder(
            input_size=self.config.input_size,
            align_long_axis=self.config.align_long_axis,
            window_size=self.config.window_size,
            encoder_layer=self.config.encoder_layer,
            name_or_path=self.config.name_or_path,
        )
        self.decoder = BARTDecoder(
            max_position_embeddings=self.config.max_position_embeddings,
            decoder_layer=self.config.decoder_layer,
            name_or_path=self.config.name_or_path,
        )

    def forward(self, image_tensors: torch.Tensor, decoder_input_ids: torch.Tensor, decoder_labels: torch.Tensor):
        """
        给定输入图像和所需的token序列计算损失,
        模型将以教师强制的方式进行训练

        参数:
            image_tensors: (batch_size, num_channels, height, width)
            decoder_input_ids: (batch_size, sequence_length, embedding_dim)
            decode_labels: (batch_size, sequence_length)
        """
        encoder_outputs = self.encoder(image_tensors)
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            encoder_hidden_states=encoder_outputs,
            labels=decoder_labels,
        )
        return decoder_outputs
    ...
    ...

编码器

Donut使用Swin-Transformer作为图像编码器,因为它在初步的文档解析研究中表现出色。该图像编码器将输入的文档图像转换为一组高维嵌入。这些嵌入将作为文本解码器的输入。

对应代码如下:

class SwinEncoder(nn.Module):
    r"""
    基于SwinTransformerDonut编码器
    使用预训练的SwinTransformer设置初始权重和配置,
    然后修改详细配置作为Donut编码器

    参数:
        input_size: 输入图像大小(宽度,高度)
        align_long_axis: 如果高度大于宽度,是否旋转图像
        window_size: SwinTransformer的窗口大小(=patch大小)
        encoder_layer: SwinTransformer编码器的层数
        name_or_path: 预训练模型名称,要么在huggingface.co.注册,要么保存在本地。
                      否则,将设置为`swin_base_patch4_window12_384`(使用`timm`)。
    """

    def __init__(
        self,
        input_size: List[int],
        align_long_axis: bool,
        window_size: int,
        encoder_layer: List[int],
        name_or_path: Union[str, bytes, os.PathLike] = None,
    ):
        super().__init__()
        self.input_size = input_size
        self.align_long_axis = align_long_axis
        self.window_size = window_size
        self.encoder_layer = encoder_layer

        self.to_tensor = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
            ]
        )

        self.model = SwinTransformer(
            img_size=self.input_size,
            depths=self.encoder_layer,
            window_size=self.window_size,
            patch_size=4,
            embed_dim=128,
            num_heads=[4, 8, 16, 32],
            num_classes=0,
        )
        self.model.norm = None

        # 使用swin初始化权重
        if not name_or_path:
            swin_state_dict = timm.create_model("swin_base_patch4_window12_384", pretrained=True).state_dict()
            new_swin_state_dict = self.model.state_dict()
            for x in new_swin_state_dict:
                if x.endswith("relative_position_index") or x.endswith