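Today I needed to compute the mean and standard deviation of the images in a training set that holds somewhere between forty and fifty thousand images. Running that on a single thread takes far too long, so I wrote a single-threaded version and asked ChatGPT to turn it into a multithreaded one. It runs and works well, so I'm writing it down here.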
Single-threaded:
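```python
# Compute the mean and standard deviation of each image, then average them over the dataset
import os
from pathlib import Path

import numpy as np
from PIL import Image

path = Path("/home/jihuawei/jhw2024spring/AceParser/Dataset/PubLayNet/publaynet/train")
imgs = os.listdir(path)

# Float accumulators: in-place float additions into integer arrays raise an error
mean = np.zeros(3)
std = np.zeros(3)
for img in imgs:
    # convert("RGB") guards against grayscale/RGBA files so every array has 3 channels
    img_np = np.array(Image.open(path / img).convert("RGB"))
    mean += img_np.mean((0, 1))  # per-channel mean of this image
    std += img_np.std((0, 1))    # per-channel std of this image
mean /= len(imgs)
std /= len(imgs)  # average of the per-image standard deviations
```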
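Combining per-image standard deviations, as the code above does, does not give the standard deviation over all pixels of the dataset: it ignores how much the image means differ from one another. If the exact pixel-level statistics are wanted, one option (a sketch, not the method used in this post) is to accumulate per-channel sums and sums of squares and use $\sigma^2 = E[x^2] - E[x]^2$. The helper name `dataset_mean_std` is hypothetical, and it takes already-loaded arrays so it can be checked without the PubLayNet files.

```python
import numpy as np


def dataset_mean_std(images):
    """Exact per-channel mean/std over every pixel of every image.

    `images` is any iterable of H x W x 3 arrays (hypothetical helper for illustration).
    """
    total = np.zeros(3)     # running sum of pixel values per channel
    total_sq = np.zeros(3)  # running sum of squared pixel values per channel
    count = 0               # number of pixels seen so far
    for img in images:
        x = np.asarray(img, dtype=np.float64).reshape(-1, 3)
        total += x.sum(axis=0)
        total_sq += (x ** 2).sum(axis=0)
        count += x.shape[0]
    mean = total / count
    # Var[x] = E[x^2] - E[x]^2; clip tiny negative values caused by rounding
    var = np.maximum(total_sq / count - mean ** 2, 0.0)
    return mean, np.sqrt(var)


# Usage with two small synthetic "images" of different sizes
rng = np.random.default_rng(0)
imgs = [rng.integers(0, 256, (4, 5, 3), dtype=np.uint8),
        rng.integers(0, 256, (6, 3, 3), dtype=np.uint8)]
mean, std = dataset_mean_std(imgs)

# Reference: stack all pixels and compute the statistics directly
all_pixels = np.concatenate([im.reshape(-1, 3) for im in imgs]).astype(np.float64)
print(np.allclose(mean, all_pixels.mean(axis=0)), np.allclose(std, all_pixels.std(axis=0)))
```

In the script above it could be fed with a generator such as `(np.array(Image.open(path / f).convert("RGB")) for f in imgs)`, so only one image is in memory at a time.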
After switching to multithreading:
```python
import os
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path

import numpy as np
from PIL import Image

path = Path("/home/jihuawei/jhw2024spring/AceParser/Dataset/PubLayNet/publaynet/train")


def calculate_mean_std(img):
    """Per-channel mean and std of a single image file."""
    img_np = np.array(Image.open(path / img).convert("RGB"))
    return img_np.mean((0, 1)), img_np.std((0, 1))


imgs = os.listdir(path)
# Float accumulators: adding float results into integer arrays would truncate them
mean_sum = np.zeros(3)
std_sum = np.zeros(3)
with ThreadPoolExecutor() as executor:
    # Run calculate_mean_std on a pool of threads; results come back in input order
    results = list(executor.map(calculate_mean_std, imgs))
for mean, std in results:
    mean_sum += mean
    std_sum += std
mean = mean_sum / len(imgs)
std = std_sum / len(imgs)
print("Mean:", mean)
print("Standard Deviation:", std)
```
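All in all, the only changes are importing `from concurrent.futures import ThreadPoolExecutor`, wrapping the per-image computation in a function `calculate_mean_std`, and running `executor.map(calculate_mean_std, imgs)` inside a `with ThreadPoolExecutor() as executor:` block. That's really all there is to it. Threads pay off here because most of the time goes into reading files and decoding images, work that Pillow and NumPy largely do without holding Python's GIL, so several images can be processed at once.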
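To see that `executor.map` gives the same numbers as a plain loop, here is a small self-contained check. It is a sketch that swaps the image files for random in-memory arrays (the `fake_imgs` list is made up for illustration), so it runs anywhere. `map` yields results in the order of its input, not in the order the threads finish, which is why summing them afterwards matches the single-threaded version.

```python
from concurrent.futures import ThreadPoolExecutor

import numpy as np


def calculate_mean_std(img_np):
    # Same computation as in the script, but on an in-memory array instead of a file
    return img_np.mean((0, 1)), img_np.std((0, 1))


rng = np.random.default_rng(42)
fake_imgs = [rng.integers(0, 256, (8, 8, 3), dtype=np.uint8) for _ in range(20)]

with ThreadPoolExecutor(max_workers=4) as executor:
    # map() submits every item to the pool and yields results in input order
    threaded = list(executor.map(calculate_mean_std, fake_imgs))

sequential = [calculate_mean_std(im) for im in fake_imgs]
same = all(np.array_equal(tm, sm) and np.array_equal(ts, ss)
           for (tm, ts), (sm, ss) in zip(threaded, sequential))
print(same)
```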
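`ThreadPoolExecutor()` also accepts a `max_workers` argument. If you leave it out, Python does not use "as many threads as possible" but a fixed default: since Python 3.8 it is `min(32, os.cpu_count() + 4)`.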
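A minimal sketch of both options: computing the default the pool would pick (the formula documented for Python 3.8 to 3.12; 3.13 uses `os.process_cpu_count()` in place of `os.cpu_count()`), and passing `max_workers` explicitly. The value 8 is only an example; for a disk-bound job like this one, timing a few values is the simplest way to find the fastest setting.

```python
import os
from concurrent.futures import ThreadPoolExecutor

# Default used by ThreadPoolExecutor on Python 3.8-3.12 when max_workers is None
default_workers = min(32, (os.cpu_count() or 1) + 4)
print("default max_workers:", default_workers)

# Passing max_workers explicitly caps the number of worker threads
with ThreadPoolExecutor(max_workers=8) as executor:
    squares = list(executor.map(lambda n: n * n, range(10)))
print(squares)
```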
On Linux you can check how busy each CPU thread is with the `htop` command.
As you can see, I've got every thread busy 🤏