【图像处理基石】如何实现一个车辆检测算法?

发布于:2025-07-27 ⋅ 阅读:(17) ⋅ 点赞:(0)

在这里插入图片描述

基于AI的车牌检测和识别算法

问题描述、应用场景与难点

问题描述

车牌检测和识别是计算机视觉领域的一个特定任务,主要包含两个核心步骤:

  1. 车牌检测:从图像中准确定位车牌的位置和区域
  2. 车牌识别:对检测到的车牌区域进行字符识别,转换为文本信息

整个流程需要处理图像输入,输出结构化的车牌文本信息,实现从图像到文字的转换。

应用场景

  • 智能交通监控系统(违章识别、交通流量统计)
  • 停车场自动化管理(入场出库自动登记、计费)
  • 高速公路收费站ETC辅助系统
  • 车辆防盗与追踪系统
  • 城市道路规划与交通流分析
  • 小区、园区等封闭区域的车辆管理

问题难点

  • 车牌在图像中占比小,特征不明显
  • 光照条件复杂(强光、逆光、夜间等)
  • 车牌存在污损、遮挡、模糊等情况
  • 车辆运动造成的图像模糊
  • 不同地区车牌样式、字符集差异大
  • 复杂背景干扰(相似颜色、纹理干扰)
  • 车牌可能存在倾斜、变形等情况

PyTorch实现车牌检测和识别算法

下面实现一个两阶段的车牌检测与识别系统:首先使用改进的YOLOv5进行车牌检测,然后使用CNN-LSTM网络进行字符识别。

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import cv2
import numpy as np
import os
from PIL import Image
import torchvision.transforms as transforms
import matplotlib.pyplot as plt

# 1. 车牌检测模型 - 基于YOLOv5的简化版本
class YOLOv5LicensePlateDetector(nn.Module):
    def __init__(self, num_classes=1):
        super().__init__()
        # 简化的YOLOv5骨干网络
        self.backbone = nn.Sequential(
            # 输入: 3x416x416
            nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(32),
            nn.LeakyReLU(0.1),
            nn.MaxPool2d(2),  # 32x208x208
            
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.1),
            nn.MaxPool2d(2),  # 64x104x104
            
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.1),
            nn.Conv2d(128, 64, kernel_size=1, stride=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.1),
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.1),
            nn.MaxPool2d(2),  # 128x52x52
            
            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.1),
            nn.Conv2d(256, 128, kernel_size=1, stride=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.1),
            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.1),
            nn.MaxPool2d(2),  # 256x26x26
            
            nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.1),
            nn.Conv2d(512, 256, kernel_size=1, stride=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.1),
            nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.1),
            nn.Conv2d(512, 256, kernel_size=1, stride=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.1),
            nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.1),
        )
        
        # 检测头 - 输出边界框和类别
        # 每个检测框包含: x, y, w, h, confidence, class_prob
        self.head = nn.Conv2d(512, (5 + num_classes) * 3, kernel_size=1, stride=1)
        self.num_classes = num_classes
        
    def forward(self, x):
        x = self.backbone(x)
        x = self.head(x)  # (batch_size, (5 + num_classes)*3, 26, 26)
        
        # 调整输出格式 [batch_size, num_anchors, height, width, 5 + num_classes]
        batch_size = x.shape[0]
        num_anchors = 3
        grid_size = x.shape[2]
        
        x = x.view(batch_size, num_anchors, 5 + self.num_classes, grid_size, grid_size)
        x = x.permute(0, 1, 3, 4, 2).contiguous()  # [batch, anchors, grid_h, grid_w, 5 + classes]
        
        return x

# 2. 车牌识别模型 - CNN + LSTM架构
class LicensePlateRecognizer(nn.Module):
    def __init__(self, num_chars, img_height=32, img_width=128):
        super().__init__()
        self.num_chars = num_chars
        
        # CNN特征提取部分
        self.cnn = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),  # 32x16x64
            
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),  # 64x8x32
            
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),  # 128x4x16
            
            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.MaxPool2d(2, 2)   # 256x2x8
        )
        
        # 计算CNN输出特征的尺寸
        self.feature_height = img_height // (2**4)  # 32 / 16 = 2
        self.feature_width = img_width // (2**4)    # 128 / 16 = 8
        
        # LSTM序列识别部分
        self.lstm = nn.LSTM(
            input_size=256 * self.feature_height,  # 256 * 2 = 512
            hidden_size=128,
            num_layers=2,
            bidirectional=True,
            batch_first=True
        )
        
        # 输出层
        self.fc = nn.Linear(256, num_chars + 1)  # +1 是为了CTC的空白字符
    
    def forward(self, x):
        # x: [batch_size, 3, height, width]
        batch_size = x.size(0)
        
        # CNN特征提取
        x = self.cnn(x)  # [batch_size, 256, 2, 8]
        
        # 调整形状以适应LSTM输入
        x = x.permute(0, 3, 1, 2)  # [batch_size, width, 256, 2]
        x = x.view(batch_size, -1, 256 * self.feature_height)  # [batch_size, 8, 512]
        
        # LSTM处理
        x, _ = self.lstm(x)  # [batch_size, 8, 256] (双向所以是128*2)
        
        # 输出层
        x = self.fc(x)  # [batch_size, 8, num_chars + 1]
        
        # 转置为CTC Loss所需的格式 [seq_len, batch_size, num_classes]
        x = x.permute(1, 0, 2)  # [8, batch_size, num_chars + 1]
        
        return x

# 3. 数据集类
class LicensePlateDataset(Dataset):
    def __init__(self, image_dir, label_file, transform=None):
        self.image_dir = image_dir
        self.transform = transform
        self.samples = []
        
        # 读取标签文件
        # 标签文件格式: 图像名 x1 y1 x2 y2 车牌字符
        with open(label_file, 'r', encoding='utf-8') as f:
            for line in f:
                parts = line.strip().split()
                if len(parts) < 6:
                    continue
                img_name = parts[0]
                bbox = tuple(map(int, parts[1:5]))
                plate_chars = parts[5]
                self.samples.append((img_name, bbox, plate_chars))
        
        # 字符集定义 (以中国车牌为例)
        self.char_set = "0123456789ABCDEFGHJKLMNPQRSTUVWXYZ京津冀晋蒙辽吉黑沪苏浙皖闽赣鲁豫鄂湘粤桂琼渝川贵云藏陕甘青宁新"
        self.char_to_idx = {char: i+1 for i, char in enumerate(self.char_set)}  # 0留给空白字符
        self.idx_to_char = {i+1: char for i, char in enumerate(self.char_set)}
        self.idx_to_char[0] = ""  # 空白字符
    
    def __len__(self):
        return len(self.samples)
    
    def __getitem__(self, idx):
        img_name, bbox, plate_chars = self.samples[idx]
        img_path = os.path.join(self.image_dir, img_name)
        
        # 读取图像
        img = cv2.imread(img_path)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        
        # 提取车牌区域
        x1, y1, x2, y2 = bbox
        plate_img = img[y1:y2, x1:x2]
        
        # 调整大小
        plate_img = cv2.resize(plate_img, (128, 32))
        
        # 转换字符为索引
        label = [self.char_to_idx[char] for char in plate_chars if char in self.char_to_idx]
        
        # 应用变换
        if self.transform:
            plate_img = self.transform(plate_img)
        
        # 返回原图、车牌图像、边界框和标签
        return {
            'original_image': transforms.ToTensor()(img),
            'plate_image': plate_img,
            'bbox': torch.tensor(bbox, dtype=torch.float32),
            'label': torch.tensor(label, dtype=torch.long),
            'label_length': torch.tensor(len(label), dtype=torch.long)
        }

# 4. 训练函数
def train_model(detector, recognizer, train_loader, val_loader, epochs=10):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    detector.to(device)
    recognizer.to(device)
    
    # 定义损失函数和优化器
    det_criterion = nn.MSELoss()  # 简化版,实际YOLO使用更复杂的损失
    rec_criterion = nn.CTCLoss(blank=0)
    det_optimizer = torch.optim.Adam(detector.parameters(), lr=1e-4)
    rec_optimizer = torch.optim.Adam(recognizer.parameters(), lr=1e-4)
    
    # 训练循环
    for epoch in range(epochs):
        detector.train()
        recognizer.train()
        train_loss = 0.0
        
        for batch in train_loader:
            original_images = batch['original_image'].to(device)
            plate_images = batch['plate_image'].to(device)
            bboxes = batch['bbox'].to(device)
            labels = batch['label']
            label_lengths = batch['label_length']
            
            # 清零梯度
            det_optimizer.zero_grad()
            rec_optimizer.zero_grad()
            
            # 检测模型前向传播
            det_outputs = detector(original_images)
            # 简化的检测损失计算
            # 实际中需要根据YOLO的输出格式计算损失
            det_loss = det_criterion(det_outputs.mean(dim=[1,2,3]), bboxes)
            
            # 识别模型前向传播
            rec_outputs = recognizer(plate_images)
            batch_size = rec_outputs.size(1)
            input_lengths = torch.full((batch_size,), rec_outputs.size(0), dtype=torch.long)
            
            # 计算CTC损失
            # 需要将标签展平
            flat_labels = torch.cat([label for label in labels])
            rec_loss = rec_criterion(rec_outputs.log_softmax(2), flat_labels, 
                                    input_lengths, label_lengths)
            
            # 总损失
            total_loss = det_loss + rec_loss
            total_loss.backward()
            
            # 更新参数
            det_optimizer.step()
            rec_optimizer.step()
            
            train_loss += total_loss.item()
        
        # 计算平均训练损失
        avg_train_loss = train_loss / len(train_loader)
        
        # 验证
        detector.eval()
        recognizer.eval()
        val_loss = 0.0
        
        with torch.no_grad():
            for batch in val_loader:
                original_images = batch['original_image'].to(device)
                plate_images = batch['plate_image'].to(device)
                bboxes = batch['bbox'].to(device)
                labels = batch['label']
                label_lengths = batch['label_length']
                
                # 检测模型
                det_outputs = detector(original_images)
                det_loss = det_criterion(det_outputs.mean(dim=[1,2,3]), bboxes)
                
                # 识别模型
                rec_outputs = recognizer(plate_images)
                batch_size = rec_outputs.size(1)
                input_lengths = torch.full((batch_size,), rec_outputs.size(0), dtype=torch.long)
                flat_labels = torch.cat([label for label in labels])
                rec_loss = rec_criterion(rec_outputs.log_softmax(2), flat_labels, 
                                        input_lengths, label_lengths)
                
                val_loss += (det_loss + rec_loss).item()
        
        avg_val_loss = val_loss / len(val_loader)
        print(f'Epoch {epoch+1}/{epochs}, Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}')
    
    return detector, recognizer

# 5. 推理函数
def predict_license_plate(detector, recognizer, image_path, dataset, threshold=0.5):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    detector.to(device)
    recognizer.to(device)
    detector.eval()
    recognizer.eval()
    
    # 图像预处理
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    
    img = cv2.imread(image_path)
    img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img_tensor = transform(img_rgb).unsqueeze(0).to(device)
    
    # 检测车牌
    with torch.no_grad():
        det_outputs = detector(img_tensor)
        # 简化的后处理,实际需要更复杂的解码
        # 这里假设我们直接得到了边界框
        # 在实际应用中,需要从YOLO输出中解码出边界框和置信度
        bbox = [50, 50, 200, 100]  # 示例值,实际应从det_outputs计算
    
    # 提取车牌区域
    x1, y1, x2, y2 = map(int, bbox)
    plate_img = img_rgb[y1:y2, x1:x2]
    plate_img = cv2.resize(plate_img, (128, 32))
    plate_tensor = transform(plate_img).unsqueeze(0).to(device)
    
    # 识别车牌字符
    with torch.no_grad():
        rec_outputs = recognizer(plate_tensor)
        # 应用softmax
        probs = F.softmax(rec_outputs, dim=2)
        # 取概率最大的字符索引
        _, preds = torch.max(probs, 2)
        preds = preds.squeeze().cpu().numpy()
    
    # 解码预测结果(去除空白字符和重复字符)
    result = []
    prev = -1
    for p in preds:
        if p != prev and p != 0:
            result.append(dataset.idx_to_char[p])
        prev = p
    plate_text = ''.join(result)
    
    # 可视化结果
    plt.figure(figsize=(10, 5))
    plt.subplot(121)
    plt.imshow(img_rgb)
    plt.plot([x1, x2, x2, x1, x1], [y1, y1, y2, y2, y1], 'r-')
    plt.title('Detected License Plate')
    
    plt.subplot(122)
    plt.imshow(plate_img)
    plt.title(f'Recognized: {plate_text}')
    plt.show()
    
    return bbox, plate_text

# 主函数
def main():
    # 数据路径(请替换为实际路径)
    train_image_dir = 'train_images/'
    val_image_dir = 'val_images/'
    train_label_file = 'train_labels.txt'
    val_label_file = 'val_labels.txt'
    
    # 数据变换
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        transforms.RandomAffine(degrees=5, translate=(0.1, 0.1)),  # 数据增强
        transforms.ColorJitter(brightness=0.2, contrast=0.2)
    ])
    
    # 创建数据集和数据加载器
    train_dataset = LicensePlateDataset(train_image_dir, train_label_file, transform)
    val_dataset = LicensePlateDataset(val_image_dir, val_label_file, transform)
    
    train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, num_workers=4)
    val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False, num_workers=4)
    
    # 初始化模型
    detector = YOLOv5LicensePlateDetector()
    recognizer = LicensePlateRecognizer(num_chars=len(train_dataset.char_set))
    
    # 训练模型
    trained_detector, trained_recognizer = train_model(
        detector, recognizer, train_loader, val_loader, epochs=10
    )
    
    # 保存模型
    torch.save(trained_detector.state_dict(), 'license_plate_detector.pth')
    torch.save(trained_recognizer.state_dict(), 'license_plate_recognizer.pth')
    
    # 测试推理
    test_image_path = 'test_image.jpg'  # 替换为测试图像路径
    bbox, plate_text = predict_license_plate(
        trained_detector, trained_recognizer, test_image_path, train_dataset
    )
    print(f'Detected License Plate: {plate_text} at position {bbox}')

if __name__ == "__main__":
    main()

代码说明

上述实现包含了一个完整的车牌检测与识别系统,主要分为以下几个部分:

  1. 车牌检测模型:基于简化的YOLOv5架构,用于从原始图像中定位车牌位置。模型输出车牌的边界框坐标和置信度。

  2. 车牌识别模型:采用CNN+LSTM架构,CNN用于提取车牌图像的特征,LSTM用于处理序列信息,实现对车牌字符的识别。使用CTC损失函数处理不定长字符序列的识别问题。

  3. 数据集类:自定义数据集类用于加载和预处理数据,支持读取图像、解析标签、提取车牌区域和字符编码等功能。

  4. 训练与推理函数:实现了模型的训练流程和推理功能,支持模型保存和结果可视化。

数据集需求与准备方法

数据集需求

一个高质量的车牌检测与识别数据集应具备以下特点:

  1. 数据规模:至少包含10,000张以上的车辆图像,涵盖不同场景和条件
  2. 标注信息
    • 车牌位置的边界框坐标
    • 车牌上的字符内容(精确到每个字符)
  3. 多样性
    • 不同类型的车辆(轿车、货车、摩托车等)
    • 不同光照条件(白天、夜晚、阴天、逆光等)
    • 不同角度和姿态(正面、侧面、倾斜等)
    • 不同天气条件(晴天、雨天、雪天等)
    • 不同背景环境(城市道路、高速公路、停车场等)
    • 不同的车牌状态(干净、污损、遮挡等)

常用公开数据集

  1. CCPD (Chinese City Parking Dataset):包含大量中国城市停车场的车牌图像,标注详细
  2. ApolloScape:包含各种交通场景的图像,其中有车牌标注
  3. CALTECH LPR Dataset:包含美国车牌的数据集
  4. SSIG-Segmented License Plate Dataset:包含多种国家车牌的数据集

数据集准备步骤

  1. 数据收集

    • 收集公开数据集
    • 自行拍摄补充特定场景数据
    • 注意遵守数据隐私法规
  2. 数据清洗

    • 去除模糊、过暗或过亮的低质量图像
    • 检查并修正错误标注
    • 去除重复样本
  3. 数据标注

    • 使用标注工具(如LabelImg、VGG Image Annotator等)标注车牌位置
    • 标注车牌上的字符内容
    • 建立统一的标注格式
  4. 数据增强

    • 几何变换:旋转、缩放、裁剪、平移、翻转
    • 颜色变换:亮度、对比度、饱和度调整
    • 添加噪声、模糊处理模拟真实场景
    • 车牌遮挡模拟
  5. 数据划分

    • 按7:2:1的比例划分为训练集、验证集和测试集
    • 确保各集合的数据分布一致

相关研究最新进展

近年来,车牌检测与识别领域的研究取得了显著进展:

  1. 端到端方法:传统的两阶段方法(先检测后识别)逐渐被端到端模型取代,如YOLO-LPR、E2E-LPR等,直接从图像输出车牌字符,简化了流程并提高了精度。

  2. Transformer架构应用:基于Transformer的模型(如DETR衍生模型)在车牌检测任务中表现出色,能够更好地处理复杂背景和小目标检测问题。

  3. 小样本学习:针对特定区域或特殊车牌类型的数据稀缺问题,小样本学习方法被引入,通过元学习等技术提高模型的泛化能力。

  4. 多模态融合:结合可见光图像和红外图像的多模态方法,提高了夜间和恶劣天气条件下的识别性能。

  5. 实时性优化:通过模型压缩、知识蒸馏等技术,使车牌识别模型能够在嵌入式设备和边缘计算平台上实时运行,满足实际应用需求。

  6. 鲁棒性提升:针对车牌污损、遮挡等问题,研究人员提出了基于生成对抗网络(GAN)的数据增强方法,以及注意力机制来聚焦关键字符区域。

  7. 跨域适应性:研究如何使模型在不同国家/地区的车牌样式之间进行迁移,减少对特定数据集的依赖。

这些进展使得车牌检测与识别系统在实际应用中的准确率和鲁棒性不断提高,为智能交通系统的发展提供了有力支持。


网站公告

今日签到

点亮在社区的每一天
去签到