以下代码用于从视频文件夹(如 ./videos/*.mp4)中,对每个视频抽取第 N 帧并保存为图片,用于图生视频训练。
考虑到数据量比较大,推荐使用 ffmpeg 来实现,
性能比较高(10 万个视频大约十多分钟就可以跑完)。代码如下:
import os
import subprocess
import time
from multiprocessing import Pool, cpu_count
from tqdm import tqdm
import psutil
import logging
from datetime import datetime
# Configure logging
def setup_logging():
    """Configure root logging to write to a timestamped file and the console.

    The log file is created as ``logs/video_process_<YYYYmmdd_HHMMSS>.log``
    relative to the current working directory; the ``logs`` directory is
    created when it does not exist yet. Both handlers use the INFO level and
    the ``time - level - message`` format.
    """
    log_dir = "logs"
    # exist_ok avoids the check-then-create race when two runs start together.
    os.makedirs(log_dir, exist_ok=True)

    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    log_file = os.path.join(log_dir, f"video_process_{timestamp}.log")

    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
        handlers=[
            # The messages logged by this script contain non-ASCII text, so
            # pin the file encoding instead of relying on the platform default.
            logging.FileHandler(log_file, encoding='utf-8'),
            logging.StreamHandler(),
        ],
    )
# Check system resources
def check_system_resources():
    """Take a snapshot of CPU, memory and root-disk utilisation via psutil.

    Returns:
        dict: ``cpu``, ``memory`` and ``disk`` hold usage percentages (disk is
        measured for ``'/'``); ``ok`` is True only when all three are below 90.
    """
    # NOTE(review): cpu_percent() is called without an interval; psutil's docs
    # say the first such call in a process returns a meaningless 0.0 - confirm
    # whether the initial check should be primed or use a short interval.
    usage = {
        'cpu': psutil.cpu_percent(),
        'memory': psutil.virtual_memory().percent,
        'disk': psutil.disk_usage('/').percent,
    }
    # Healthy only if every individual metric stays under the 90% threshold.
    usage['ok'] = all(percent < 90 for percent in usage.values())
    return usage
def _discard_partial_output(path):
    """Best-effort removal of a missing/empty/partial output image."""
    try:
        os.remove(path)
    except OSError:
        # Nothing was written, or the file cannot be removed; either way
        # there is nothing more useful to do here.
        pass


def extract_frame_ffmpeg(video_path, output_path, frame_number=10):
    """Extract one frame from a video with ffmpeg and save it as a JPEG.

    The image is written to ``<output_path>/<video basename>.jpg``. If a
    non-empty file with that name already exists the video is skipped and
    reported as successful.

    Args:
        video_path: Path of the input video file.
        output_path: Directory that receives the extracted image.
        frame_number: 1-based index of the frame to extract (default 10).

    Returns:
        tuple: ``(True, video_path)`` on success or skip, otherwise
        ``(False, message)`` where ``message`` describes the failure. The
        function never raises; unexpected errors are reported in ``message``.
    """
    try:
        # The select filter below can never match a negative frame index.
        if frame_number < 1:
            return False, (f"Invalid frame_number {frame_number!r} for "
                           f"{video_path}: must be >= 1")

        video_filename = os.path.splitext(os.path.basename(video_path))[0]
        output_filepath = os.path.join(output_path, f"{video_filename}.jpg")

        # Skip work that is already done. An empty file is treated as the
        # leftover of an earlier failed run and is regenerated.
        if os.path.exists(output_filepath) and os.path.getsize(output_filepath) > 0:
            logging.info(f"文件已存在,跳过处理: {output_filepath}")
            return True, video_path

        cmd = [
            'ffmpeg',
            '-i', video_path,
            # The select filter's frame counter ``n`` starts at 0, hence the
            # "- 1". The backslash escapes the comma for ffmpeg's filter
            # parser; it is doubled so Python does not see an invalid escape.
            '-vf', f'select=eq(n\\,{frame_number - 1})',
            '-vframes', '1',
            '-q:v', '2',  # high JPEG quality (lower value means better)
            '-y',
            output_filepath,
        ]

        try:
            # run() kills and reaps the child itself when the timeout expires,
            # so no zombie process or open pipe is left behind.
            result = subprocess.run(
                cmd,
                stdin=subprocess.DEVNULL,   # keep ffmpeg away from our stdin
                stdout=subprocess.DEVNULL,  # stdout is never inspected
                stderr=subprocess.PIPE,
                universal_newlines=True,
                errors='replace',           # odd bytes in stderr must not raise
                timeout=30,
            )
        except subprocess.TimeoutExpired:
            # Remove any half-written image so a later run does not skip it.
            _discard_partial_output(output_filepath)
            return False, f"Timeout processing {video_path}"

        if result.returncode != 0:
            _discard_partial_output(output_filepath)
            return False, f"Error processing {video_path}: {result.stderr}"

        # ffmpeg can exit with status 0 without writing an image, e.g. when
        # the video has fewer frames than requested; verify the output.
        if not os.path.exists(output_filepath) or os.path.getsize(output_filepath) == 0:
            _discard_partial_output(output_filepath)
            return False, (f"Error processing {video_path}: "
                           f"no image produced for frame {frame_number}")

        return True, video_path
    except Exception as e:
        # Worker boundary: report instead of raising so one bad input cannot
        # abort the whole pool.
        return False, f"Exception processing {video_path}: {str(e)}"
def process_video_worker(args):
    """Pool worker entry point: unpack one task and extract its frame.

    Args:
        args: A ``(video_path, output_dir)`` pair, packed into a single
            argument so it can be dispatched through ``Pool.imap_unordered``.

    Returns:
        The ``(success, message)`` result of ``extract_frame_ffmpeg`` called
        with its default frame number.
    """
    source_video, frames_dir = args
    outcome = extract_frame_ffmpeg(source_video, frames_dir)
    return outcome
def _warn_if_resources_constrained():
    """Log a warning when system usage is high; return True if it was."""
    resources = check_system_resources()
    if resources['ok']:
        return False
    logging.warning(f"系统资源使用率较高: CPU {resources['cpu']}%, "
                    f"内存 {resources['memory']}%, 磁盘 {resources['disk']}%")
    return True


def process_videos_parallel(input_dir, output_dir, num_processes=None, batch_size=100):
    """Extract a frame from every MP4 in ``input_dir`` using a process pool.

    Videos are dispatched to the workers in batches; after each batch the
    system resources are checked and, when usage is high, processing pauses
    for five seconds. Per-video failures are logged and summarised at the end
    rather than raised.

    Args:
        input_dir: Directory scanned (non-recursively) for ``.mp4`` files;
            the extension is matched case-insensitively.
        output_dir: Directory that receives the images; created if missing.
        num_processes: Worker process count. ``None`` means
            ``min(cpu_count(), 8)``.
        batch_size: Number of videos handled between resource checks.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    # Fail early: a zero step would otherwise blow up inside range() and a
    # negative one would silently process nothing.
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    # Create the output directory (race-free if it already exists).
    os.makedirs(output_dir, exist_ok=True)

    # Initial resource check (warning only).
    _warn_if_resources_constrained()

    # Choose the number of worker processes.
    if num_processes is None:
        num_processes = min(cpu_count(), 8)  # use at most 8 processes

    # Collect the video files; sorted for a reproducible processing order.
    video_files = sorted(
        name for name in os.listdir(input_dir)
        if name.lower().endswith('.mp4')
    )
    total_videos = len(video_files)
    if total_videos == 0:
        logging.warning(f"在 {input_dir} 中没有找到MP4文件")
        return

    # Build the task list consumed by process_video_worker.
    tasks = [(os.path.join(input_dir, video), output_dir)
             for video in video_files]
    total_batches = (total_videos - 1) // batch_size + 1

    # Processing statistics.
    successful = 0
    failed = 0
    failed_videos = []

    # One pool serves every batch: starting a fresh set of worker processes
    # per batch is pure overhead when there are thousands of batches.
    with Pool(processes=num_processes) as pool:
        batch_starts = range(0, total_videos, batch_size)
        for batch_no, start in enumerate(batch_starts, start=1):
            batch = tasks[start:start + batch_size]
            logging.info(f"处理批次 {batch_no}/{total_batches}, "
                         f"包含 {len(batch)} 个视频")

            # tqdm renders a per-batch progress bar as results arrive.
            for success, message in tqdm(
                pool.imap_unordered(process_video_worker, batch),
                total=len(batch),
                desc="处理进度"
            ):
                if success:
                    successful += 1
                else:
                    failed += 1
                    failed_videos.append(message)
                    logging.error(message)

            # Re-check resources after each batch and back off when busy.
            if _warn_if_resources_constrained():
                time.sleep(5)  # give the system some time to recover

    # Final summary.
    logging.info("\n处理完成统计:")
    logging.info(f"总计视频: {total_videos}")
    logging.info(f"成功处理: {successful}")
    logging.info(f"处理失败: {failed}")
    if failed > 0:
        logging.info("\n失败的视频:")
        for msg in failed_videos:
            logging.info(msg)
def main():
    """Script entry point: configure logging and run the extraction job.

    Processes ``./videos`` into ``./frames`` with 4 worker processes and
    batches of 50 videos, then logs the elapsed wall-clock time. A keyboard
    interrupt or any other exception is logged instead of propagating.
    """
    # Logging first, so everything below is recorded.
    setup_logging()

    # Job configuration.
    source_dir = "./videos"   # input video directory
    frames_dir = "./frames"   # output image directory
    pool_options = {
        "num_processes": 4,   # number of worker processes
        "batch_size": 50,     # videos per batch
    }

    # Reference point for the total-duration report.
    started_at = time.time()

    try:
        process_videos_parallel(source_dir, frames_dir, **pool_options)

        # Only reached when processing finished without raising.
        elapsed_time = time.time() - started_at
        logging.info(f"\n总耗时: {elapsed_time:.2f} 秒")
    except KeyboardInterrupt:
        logging.warning("\n用户中断处理")
    except Exception as e:
        logging.error(f"处理过程中发生错误: {str(e)}")
# Run the job only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()