Multithreaded Processing in Python

Published: 2024-05-16

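Today I needed to compute the mean and standard deviation of the images in a training dataset. It holds somewhere between forty and fifty thousand images, and running over them in a single thread takes too long. So I wrote a single-threaded version first and asked ChatGPT to convert it to a multithreaded one. It ran without problems and worked very well, so I'm writing it down here.

Single-threaded: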

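# Compute the per-channel mean and standard deviation of each image, then average them
import os
from pathlib import Path

import numpy as np
from PIL import Image

imgs = os.listdir("/home/jihuawei/jhw2024spring/AceParser/Dataset/PubLayNet/publaynet/train")

# Float accumulators: adding float values in place to an integer array raises a casting error
mean = np.zeros(3)
std = np.zeros(3)

path = Path("/home/jihuawei/jhw2024spring/AceParser/Dataset/PubLayNet/publaynet/train")
for img in imgs:
    img_np = np.array(Image.open(path / img))
    mean += img_np.mean((0, 1))  # per-channel mean over height and width
    std += img_np.std((0, 1))    # per-channel standard deviation over height and width

# Average the per-image statistics over the number of images
mean /= len(imgs)
std /= len(imgs)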

After converting it to multithreading:

import os
from concurrent.futures import ThreadPoolExecutor
from PIL import Image
from pathlib import Path
import numpy as np

def calculate_mean_std(img):
    path = Path("/home/jihuawei/jhw2024spring/AceParser/Dataset/PubLayNet/publaynet/train")
    img_np = np.array(Image.open(path / img))
    return img_np.mean((0, 1)), img_np.std((0, 1))

imgs = os.listdir("/home/jihuawei/jhw2024spring/AceParser/Dataset/PubLayNet/publaynet/train")

# Float accumulators: with integer arrays, casting="unsafe" silently truncates every partial sum
mean_sum = np.zeros(3)
std_sum = np.zeros(3)

with ThreadPoolExecutor() as executor:
    # executor.map returns the results in the same order as imgs
    results = list(executor.map(calculate_mean_std, imgs))
    for mean, std in results:
        mean_sum += mean
        std_sum += std

mean = mean_sum / len(imgs)
std = std_sum / len(imgs)
print("Mean:", mean)
print("Standard Deviation:", std)
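
One caveat worth spelling out: both scripts average the per-image standard deviations, which is generally not the same number as the standard deviation taken over all pixels of the dataset. The sketch below illustrates the difference on two tiny made-up pixel lists (one channel each, values invented for illustration) using only the standard library. The pooled version accumulates a count, a sum, and a sum of squares, which would also fit the `executor.map` pattern. Treat it as an illustration under these assumptions, not as part of the original code.

```python
import math
import statistics

# Two tiny made-up "images", each flattened to the pixel values of one channel.
image_a = [10, 20, 30, 40]
image_b = [200, 210]

# What the scripts above compute: the average of the per-image standard deviations.
mean_of_stds = (statistics.pstdev(image_a) + statistics.pstdev(image_b)) / 2

# Pooled alternative: accumulate count, sum, and sum of squares over all pixels.
count = 0
total = 0
total_sq = 0
for pixels in (image_a, image_b):
    count += len(pixels)
    total += sum(pixels)
    total_sq += sum(p * p for p in pixels)

pooled_mean = total / count
# Population variance via E[x^2] - (E[x])^2, then square root.
pooled_std = math.sqrt(total_sq / count - pooled_mean ** 2)

print("Average of per-image stds:", mean_of_stds)
print("Std over all pixels:", pooled_std)
```

Which of the two numbers is appropriate depends on how the statistics will be used, so this is only meant to make the distinction visible.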

Overall, all it took was adding `from concurrent.futures import ThreadPoolExecutor`, wrapping the computation in a function `calculate_mean_std`, and calling `executor.map(calculate_mean_std, imgs)` inside a `with ThreadPoolExecutor() as executor:` block. It really is that simple.
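
To see the pattern in isolation, here is a minimal sketch that swaps the image loading for a toy function, so it runs without any dataset. The function `fake_stats` and the input list are made up for illustration; only `ThreadPoolExecutor` and `executor.map` match the script above. It also shows that `executor.map` returns results in the same order as the inputs, which is why the sums can be accumulated in a plain loop.

```python
import time
from concurrent.futures import ThreadPoolExecutor

def fake_stats(x):
    # Hypothetical stand-in for calculate_mean_std: sleep briefly to mimic
    # file I/O, then return a pair of values derived from the input.
    time.sleep(0.01)
    return x * 2, x * 3

inputs = list(range(8))

with ThreadPoolExecutor() as executor:
    # map() hands every item to the pool and yields results in input order.
    results = list(executor.map(fake_stats, inputs))

doubles = [d for d, _ in results]
triples = [t for _, t in results]
print(doubles)
print(triples)
```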

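`ThreadPoolExecutor()` accepts an optional `max_workers` argument. If you don't specify it, Python picks a default based on the CPU count: `min(32, os.cpu_count() + 4)` in Python 3.8 through 3.12, so it is capped at 32 threads rather than being unlimited.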
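
As a rough sketch of setting the pool size explicitly, the snippet below passes `max_workers` and also prints the default that Python 3.8 through 3.12 would use. The value 4 and the helper `square` are arbitrary choices for illustration, not something taken from the script above.

```python
import os
from concurrent.futures import ThreadPoolExecutor

def square(x):
    # Hypothetical toy task used only to exercise the pool.
    return x * x

# Default pool size in Python 3.8 through 3.12 when max_workers is omitted.
default_workers = min(32, (os.cpu_count() or 1) + 4)
print("Default max_workers would be:", default_workers)

# Cap the pool at 4 threads (an arbitrary number chosen for this example).
with ThreadPoolExecutor(max_workers=4) as executor:
    squares = list(executor.map(square, range(5)))

print(squares)
```

A smaller pool can be useful on a shared machine where you don't want to occupy every core.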

On Linux, you can check thread usage with the `htop` command.

As you can see, I've taken up all the threads 🤏