OCR图片矫正、表格检测及裁剪综合实践

发布于:2024-08-08 ⋅ 阅读:(86) ⋅ 点赞:(0)

问题描述

实际工程中,我们经常需要对图片进行预处理,比如:

1、图片是倾斜的

2、图片背景需要处理掉

3、图片的公章需要剔除

4、图片过暗,过亮

5、图片表格检测

6、图片表格版面分析

。。。。。。等等各种情况。

结果展示

本文以表格图片为例,介绍如何进行矫正、表格检测及裁剪保存图片。

原始图片

矫正之后

表格检测

裁剪之后

代码详解

图片矫正

通过多次旋转计算最佳旋转角度并应用旋转矩阵矫正图片

#coding=utf-8
import cv2
import numpy as np
def rotate_image(image, angle):
    (h, w) = image.shape[: 2]
    center = (w // 2, h // 2)
    M = cv2.getRotationMatrix2D(center, angle, 1.0)
    corrected = cv2.warpAffine(image, M, (w, h), flags = cv2.INTER_CUBIC, \
        borderMode = cv2.BORDER_REPLICATE)
    return corrected

def determine_score(arr):
     histogram = np.sum(arr, axis = 2, dtype = float)
     score = np.sum((histogram[..., 1 :] - histogram[..., : -1]) ** 2, \
        axis = 1, dtype = float)
     return score

def correct_skew(image, delta = 0.05, limit = 10):
     thresh = cv2.threshold(image, 0, 255, cv2.THRESH_BINARY_INV + \
        cv2.THRESH_OTSU)[1]
     angles = np.arange(-limit, limit + delta, delta)
     img_stack = np.stack([rotate_image(thresh, angle) for angle \
        in angles], axis = 0)
     scores = determine_score(img_stack)
     best_angle = angles[np.argmax(scores)]
     corrected = rotate_image(image, best_angle)
     return best_angle, corrected
if __name__ == "__main__":
    batch_folder = r'D:\temp\pics'
    out_folder = r'D:\temp\picsout/'
    for root, dirs, files in os.walk(batch_folder):
        for file in files:
            file_path = os.path.join(root, file)
            file_path = file_path.replace('\\', '/')
            img = cv2.imread(file_path, 0)
            angle, corrected = correct_skew(img)
            print(angle,file_path)
            cv2.imwrite(out_folder + file_path.split('/')[-1], corrected)

表格识别

通过微软的table-transformer-detection进行表格,该模型可在Hugging Face 官网下载。

图片裁剪

通过PIL里的Image的crop方法对指定的let_top,right_bottom进行裁剪。

相关代码见下:

from PIL import Image
import matplotlib.pyplot as plt
file_path = r'D:\temp\pics\efb.jpg'
image = Image.open(file_path).convert("RGB")
width, height = image.size
image.resize((int(width * 0.5), int(height * 0.5)))
from transformers import DetrFeatureExtractor

feature_extractor = DetrFeatureExtractor()
encoding = feature_extractor(image, return_tensors="pt")
encoding.keys()
from transformers import TableTransformerForObjectDetection
model = TableTransformerForObjectDetection.from_pretrained(r"D:\Modles\table-transformer-detection/")
import torch

with torch.no_grad():
    outputs = model(**encoding)
COLORS = [[0.000, 0.447, 0.741], [0.850, 0.325, 0.098], [0.929, 0.694, 0.125],
          [0.494, 0.184, 0.556], [0.466, 0.674, 0.188], [0.301, 0.745, 0.933]]


def plot_results(pil_img, scores, labels, boxes):
    plt.figure(figsize=(16, 10))
    plt.imshow(pil_img)
    ax = plt.gca()
    colors = COLORS * 100
    for score, label, (xmin, ymin, xmax, ymax), c in zip(scores.tolist(), labels.tolist(), boxes.tolist(), colors):
        ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin,
                                   fill=False, color=c, linewidth=3))
        text = f'{model.config.id2label[label]}: {score:0.2f}'
        ax.text(xmin, ymin, text, fontsize=15,
                bbox=dict(facecolor='yellow', alpha=0.5))
    plt.axis('off')
    plt.show()

if __name__ == "__main__":
    width, height = image.size
    results = feature_extractor.post_process_object_detection(outputs, threshold=0.2, target_sizes=[(height, width)])[0]
    plot_results(image, results['scores'], results['labels'], results['boxes'])
    print(results['scores'])
    print(results['labels'])
    print(results['boxes'])
    print(results['boxes'][0][0],type((results['boxes'][0][0])))
    x0=int(results['boxes'][0][0].item())-50
    y0=int(results['boxes'][0][1].item())-50
    x1=int(results['boxes'][0][2].item())+50
    y1=int(results['boxes'][0][3].item())+50
    img2 = image.crop((x0,y0,x1,y1))
    img2.save(r"D:\\efb.jpg")


网站公告

今日签到

点亮在社区的每一天
去签到