深度学习篇---PaddleDetection&PaddleOCR

发布于:2025-03-28 ⋅ 阅读:(29) ⋅ 点赞:(0)


前言

本文简单介绍了PaddleDetection和PaddleOCR相结合的示例代码,通过两个PaddlePaddle框架下的工具包结合使用同时达到图像识别和文本识别的功能。


1.代码

import cv2
import re
import serial
import sqlite3
from datetime import datetime
from paddledetection.deploy.python.infer import Detector
from paddleocr import PaddleOCR

# ========== 配置区域 ==========
SERIAL_PORT = '/dev/ttyUSB0'  # 串口设备
BAUDRATE = 9600               # 波特率
DB_NAME = 'express.db'        # 数据库名称
DET_MODEL_DIR = './jd_sf_model/'  # 检测模型路径
# =============================

class ExpressInfoProcessor:
    def __init__(self):
        # 初始化检测模型
        self.detector = Detector(
            model_dir=DET_MODEL_DIR,
            device='GPU',
            threshold=0.6
        )
        
        # 初始化OCR模型
        self.ocr = PaddleOCR(use_angle_cls=True, lang='ch')
        
        # 初始化串口
        self.ser = serial.Serial(SERIAL_PORT, BAUDRATE)
        
        # 初始化数据库
        self.conn = sqlite3.connect(DB_NAME)
        self._init_db()

    def _init_db(self):
        """初始化数据库表结构"""
        cursor = self.conn.cursor()
        cursor.execute('''CREATE TABLE IF NOT EXISTS express_info
                       (id INTEGER PRIMARY KEY AUTOINCREMENT,
                        company TEXT,
                        track_no TEXT,
                        name TEXT,
                        phone TEXT,
                        create_time TIMESTAMP)''')
        self.conn.commit()

    def process_image(self, img_path):
        """处理图像的主流程"""
        # 读取图像
        img = cv2.imread(img_path)
        if img is None:
            raise FileNotFoundError(f"Image {img_path} not found")

        # 检测快递公司
        company = self._detect_company(img)
        if not company:
            print("未识别到有效快递公司")
            return

        # OCR文字识别
        ocr_result = self.ocr.ocr(img, cls=True)
        text_list = [line[1][0] for line in ocr_result]

        # 信息提取
        info = self._extract_info(text_list)
        
        # 发送串口指令
        self._send_serial(company)
        
        # 存储到数据库
        self._save_to_db(company, info)

    def _detect_company(self, img):
        """检测快递公司"""
        results = self.detector.predict([img], visual=False)
        
        # 解析检测结果(假设0:京东,1:顺丰)
        for result in results:
            if len(result['boxes']) > 0:
                class_id = int(result['boxes'][0][0])
                score = result['boxes'][0][1]
                if score > 0.6:
                    return '京东' if class_id == 0 else '顺丰'
        return None

    def _extract_info(self, text_list):
        """从OCR结果提取结构化信息"""
        return {
            'track_no': self._find_track_no(text_list),
            'name': self._find_name(text_list),
            'phone': self._find_phone(text_list)
        }

    def _find_track_no(self, texts):
        """查找快递单号"""
        # 优先查找包含关键字的条目
        for text in texts:
            if any(kw in text for kw in ['单号', '快递单号', '运单号']):
                match = re.search(r'\d{10,20}', text)
                if match:
                    return match.group()
        
        # 全局搜索长数字
        for text in texts:
            match = re.search(r'\d{12,20}', text)
            if match:
                return match.group()
        return ''

    def _find_phone(self, texts):
        """查找电话号码"""
        for text in texts:
            match = re.search(r'(1[3-9]\d{9})', text)
            if match:
                return match.group()
        return ''

    def _find_name(self, texts):
        """查找收件人姓名"""
        # 查找包含关键字的条目
        for text in texts:
            if any(kw in text for kw in ['收件人', '姓名', '收货人']):
                parts = re.split(r'[::]', text)
                if len(parts) > 1:
                    name = parts[-1].strip()
                    if re.match(r'^[\u4e00-\u9fa5]{2,4}$', name):
                        return name
        
        # 匹配纯中文姓名
        for text in texts:
            if re.match(r'^[\u4e00-\u9fa5]{2,4}$', text):
                return text
        return ''

    def _send_serial(self, company):
        """发送串口指令"""
        cmd = '1' if company == '京东' else '2'
        self.ser.write(cmd.encode())
        print(f"已发送指令:{cmd}")

    def _save_to_db(self, company, info):
        """保存到数据库"""
        cursor = self.conn.cursor()
        cursor.execute('''INSERT INTO express_info 
                       (company, track_no, name, phone, create_time)
                       VALUES (?, ?, ?, ?, ?)''',
                       (company,
                        info['track_no'],
                        info['name'],
                        info['phone'],
                        datetime.now().strftime("%Y-%m-%d %H:%M:%S")))
        self.conn.commit()
        print("数据已存储")

    def __del__(self):
        """资源清理"""
        self.ser.close()
        self.conn.close()

if __name__ == "__main__":
    processor = ExpressInfoProcessor()
    processor.process_image("test.jpg")
    

2.代码介绍

这段代码是一个用于处理快递信息的Python程序,具体功能包括读取图像中的快递信息,识别快递公司、提取快递单号、收件人姓名和联系电话,然后通过串口发送指令,并将这些信息存储到SQLite数据库中。下面是对代码中各部分的详细介绍:

2.1 导入模块

代码一开始导入了需要使用的模块,包括OpenCV(cv2)用于图像处理、正则表达式模块(re)用于文本匹配串口通信模块(serial)SQLite数据库模块(sqlite3)****、datetime模块用于处理时间PaddlePaddle的物体检测模型(Detector)和OCR模型(PaddleOCR)

2.2 配置区域

定义了一些配置常量,如串口设备、波特率、数据库名称、检测模型路径等。

2.3 ExpressInfoProcessor类

这是主要的处理类,包含以下方法和功能:

__init__方法:初始化函数,初始化物体检测模型、OCR模型、串口和数据库连接

_init_db方法:初始化数据库表结构,如果表不存在则创建。

process_image方法:处理图像的主要流程,包括读取图像、检测快递公司、OCR文字识别、提取信息、发送串口指令和存储到数据库。

_detect_company方法:检测快递公司,使用物体检测模型判断是京东还是顺丰。

_extract_info方法:从OCR结果中提取结构化信息,包括快递单号、姓名和电话号码。

_find_track_no、_find_phone、_find_name方法:分别用于查找快递单号、电话号码和姓名。

_send_serial方法:发送串口指令,根据公司类型发送不同的指令。

_save_to_db方法:将信息保存到数据库中。

__del__方法:资源清理方法,在对象销毁时关闭串口和数据库连接。

2.4 主程序

在if __name__ == "__main__"下初始化ExpressInfoProcessor对象,并调用process_image方法处理名为"test.jpg"的图像。

3.使用说明

3.1环境准备

pip install paddlepaddle paddleocr paddledetection serial pyserial opencv-python

3.2模型准备

  1. 训练京东/顺丰图标检测模型(使用PP-YOLO等算法)
  2. 将训练好的模型保存到jd_sf_model目录
  3. 目录应包含:model.pdmodel, model.pdiparams, infer_cfg.yml

3.3数据库初始化

代码首次运行时会自动创建SQLite数据库和表结构

3.4串口配置

  1. 根据实际硬件修改SERIAL_PORTBAUDRATE参数
  2. Linux系统查看端口:ls /dev/tty*
  3. Windows系统端口通常为COMx格式

3.5信息提取优化

  1. 可根据实际快递单样式调整正则表达式匹配规则
  2. 添加更多关键字匹配规则提高准确性

3.6注意事项

  1. 检测模型需要自行训练并放置到指定目录
  2. 实际快递单的OCR识别效果取决于图像质量
  3. 信息提取规则可能需要根据具体快递单样式调整
  4. 串口通信需要正确配置端口和波特率
  5. 数据库文件会自动生成在当前目录下