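"""ByteTrack 漏检补齐 (filling in missed detections for ByteTrack).

1. 人流慢速运动时跟踪效果较好;偶尔出现漏检,跟踪器可以自动补齐。
   (Slow-moving crowds track well; occasional missed detections are bridged by the tracker.)
2. 快速运动、频繁遮挡时,效果可能不好。
   (Fast motion and frequent occlusion may give poor results.)
* 如果出现漏检,则倒着跟踪:把丢失的检测框拷贝出来,保留下来继续参与跟踪。
  (On a missed detection, track backwards: copy the lost boxes and keep them in the tracking input.)
  该方法有时效果不是很好。(This approach sometimes does not work well.)
"""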
from collections import defaultdict
import cv2
import numpy as np
import torchvision
from ultralytics import YOLO
import pickle
import torch
from torchvision.ops import box_iou
from log import logger
import time
import os
from addict import Dict
from track.byte_tracker import BYTETracker
import math
def get_color(idx):
    """Return a deterministic 3-channel color tuple for a track id.

    The id is first scaled by 5. Each channel is then that value times a
    different constant (37, 17, 29), taken modulo 255.
    """
    scaled = 5 * idx
    return tuple(scaled * factor % 255 for factor in (37, 17, 29))
class YOLO_Class():
    """YOLO detector paired with a ByteTrack multi-object tracker.

    ``get_bytetrack_bbox`` processes a whole video. For each frame it calls
    the module-level ``detect_image_yolo`` and then ``yolo_byte_track``.
    ``yolo_byte_track`` updates ``self.tracker`` and draws the resulting
    tracks onto the frame.
    """

    def __init__(self, model_path, device="cuda:0"):
        """Load the YOLO weights and create the tracker.

        Args:
            model_path: Path to the YOLO weights file.
            device: Target device string. NOTE(review): it is only stored on
                the instance. It is not passed to the model or to
                ``detect_image_yolo``; confirm whether it should be.
        """
        self.model = YOLO(model_path)  # YOLO-12 detector; tracking is done by BYTETracker below
        self.device = device
        # Tracker settings. The same object is also passed to every update()
        # call and is adjusted per frame in yolo_byte_track().
        self.par_args = Dict(
            {"track_thresh": 0.5, "track_buffer": 30, "match_thresh": 0.9, "min_box_area": 10, "mot20": False})
        self.tracker = BYTETracker(self.par_args, frame_rate=20)

    def yolo_byte_track(self, detect_bboxes, frame):
        """Update the tracker with one frame's detections and draw the tracks.

        Args:
            detect_bboxes: This frame's detections, one row per box, in the
                format produced by ``detect_image_yolo``:
                ``[x1, y1, x2, y2, score, label]``.
            frame: The image the boxes refer to. It is annotated in place.

        Returns:
            The same ``frame`` object. Every returned track with positive
            width and height gets a colored rectangle and a
            ``"<track_id> <score>"`` label. If ``detect_bboxes`` is empty,
            the tracker is not updated and the frame is returned untouched.
        """
        if len(detect_bboxes) == 0:
            # NOTE(review): update() is skipped on frames with no detections,
            # so the tracker never sees those frames. Confirm this is intended.
            return frame

        title_color = (0, 255, 255)
        # Crowded frames (more than 4 boxes): keep lost tracks longer and
        # loosen matching. NOTE(review): 1.6 is above 1.0. If match_thresh is
        # an IoU-distance threshold, as in upstream ByteTrack, this accepts
        # any match. Confirm against track.byte_tracker.
        if len(detect_bboxes) > 4:
            self.par_args.track_buffer = 60
            self.par_args.match_thresh = 1.6
        else:
            self.par_args.track_buffer = 30
            self.par_args.match_thresh = 0.9

        online_targets = self.tracker.update(np.array(detect_bboxes), [frame.shape[0], frame.shape[1]],
                                             (frame.shape[0], frame.shape[1]), self.par_args)
        for t in online_targets:
            x1, y1, w, h = t.tlwh
            if not (w > 0 and h > 0):
                continue  # skip degenerate boxes
            track_id = t.track_id
            box_color = get_color(track_id)
            intbox = tuple(map(int, (x1, y1, x1 + w, y1 + h)))
            cv2.rectangle(frame, intbox[0:2], intbox[2:4], color=box_color, thickness=2)
            cv2.putText(frame, f'{track_id} {t.score:.2f} ', (intbox[0], intbox[1] - 5), cv2.FONT_HERSHEY_PLAIN,
                        1.8, title_color, thickness=2)
        return frame

    def _fill_missing_boxes(self, results, last_box, frame_id, img_hw):
        """Experimental: re-add previous-frame boxes missing from the current frame.

        A throwaway tracker is updated first with the current detections and
        then with the previous frame's boxes ("tracking backwards"). Tracks
        reported for the previous frame whose IDs did not appear for the
        current frame are appended to ``results`` as
        ``[x1, y1, x2, y2, score, 0]`` rows. The author notes this sometimes
        works poorly.
        """
        tracker2 = BYTETracker(self.par_args, frame_rate=3)
        track_now = tracker2.update(results, img_hw, img_hw, self.par_args)
        track_last = tracker2.update(last_box, img_hw, img_hw, self.par_args)
        now_ids = {t.track_id for t in track_now}
        lost_rows = []
        for t in track_last:
            if t.track_id in now_ids:
                continue
            x1, y1, bw, bh = t.tlwh
            print('add box', frame_id, x1, y1, bw, bh)
            lost_rows.append(np.asarray([x1, y1, x1 + bw, y1 + bh, t.score, 0]))
        if lost_rows:
            results = np.vstack([results, *lost_rows])
        return results

    def get_bytetrack_bbox(self, video_path, video_id, output_path="", debug: bool = False,
                           fill_missing: bool = False):
        """Run detection and ByteTrack over a video, then display and save the result.

        For each frame, detections from ``detect_image_yolo`` are tracked and
        drawn by ``yolo_byte_track``. The annotated frame is written to
        ``output_path`` at the capture resolution. It is then shown in an
        OpenCV window, downscaled if it exceeds about 1.3 MP. The window waits
        for a key press on every frame, and Esc stops processing.

        Args:
            video_path: Input video file.
            video_id: Identifier used to name the debug directory.
            output_path: Output video path. If empty, no video is written.
            debug: If True, create ``yolov12/debug/<video_id>``.
            fill_missing: Enable the experimental back-fill of boxes that
                disappeared since the previous frame. Off by default; it was
                previously hard-disabled in the code.

        Raises:
            RuntimeError: If the input video cannot be opened, or if the
                output writer cannot be initialised.
        """
        if debug:
            debug_dir = f"yolov12/debug/{video_id}"
            os.makedirs(debug_dir, exist_ok=True)  # make sure the debug directory exists

        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            raise RuntimeError(f"无法打开视频: {video_path}")
        out = None
        try:
            fps = cap.get(cv2.CAP_PROP_FPS) or 30  # some files report no FPS; fall back to a default
            w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            logger.info(f"视频总帧数: {total_frames}, fps: {fps}, 宽: {w}, 高: {h}")
            if output_path:
                fourcc = cv2.VideoWriter_fourcc(*"mp4v")
                out = cv2.VideoWriter(output_path, fourcc, fps, (w, h))
                if not out.isOpened():
                    raise RuntimeError("VideoWriter 初始化失败,请检查编码器 fourcc 或路径。")

            frame_id = 0
            last_box = []
            while cap.isOpened():
                ok, frame = cap.read()
                if not ok:
                    break
                t0 = time.time()
                results = detect_image_yolo(self.model, frame)
                if fill_missing and len(last_box) > len(results):
                    results = self._fill_missing_boxes(results, last_box, frame_id, frame.shape[:2])
                last_box = results
                t1 = time.time()
                frame = self.yolo_byte_track(results, frame)
                print(f"{frame_id} det_track time {time.time() - t0:.3f}s track_time {time.time() - t1:.3f}s")

                # Write before any display resize: the writer was opened with
                # (w, h), and VideoWriter expects frames of exactly that size.
                if out is not None:
                    out.write(frame)

                display = frame
                if np.prod(frame.shape[:2]) > 1000 * 1300:
                    # Shrink large frames to roughly 1.2 MP for on-screen viewing only.
                    x_scale = np.sqrt(1000 * 1200 / np.prod(frame.shape[:2]))
                    display = cv2.resize(frame, None, fx=x_scale, fy=x_scale, interpolation=cv2.INTER_AREA)
                cv2.imshow("YOLO Track", display)
                # waitKey(0) blocks until a key is pressed, so playback
                # advances one frame per key press. Esc (27) quits.
                if cv2.waitKey(0) & 0xFF == 27:
                    break
                frame_id += 1
        finally:
            cap.release()
            if out is not None:
                out.release()
            cv2.destroyAllWindows()
def detect_image_yolo(yolo_model, image, imgsz=640, conf=0.4, min_area=60*40, max_len=0,
                      cls_id=0, nms_iou=0.5):
    """Run YOLO on one image and return filtered detections in ByteTrack format.

    Processing steps:
      1. Keep only boxes whose class equals ``cls_id`` (0 by default; the
         original code calls this the person class).
      2. Apply NMS at ``nms_iou``.
      3. Drop boxes whose area is below ``min_area``.

    Args:
        yolo_model: Callable model. ``yolo_model(image, ...)[0].boxes`` must
            expose ``cls``, ``xyxy`` and ``conf`` tensors.
        image: Input image, passed straight to the model.
        imgsz: Inference size forwarded to the model.
        conf: Confidence threshold forwarded to the model.
        min_area: Minimum box area (in xyxy pixel units) to keep.
        max_len: Unused; kept for interface compatibility.
        cls_id: Class index to keep.
        nms_iou: IoU threshold passed to ``torchvision.ops.nms``.

    Returns:
        ``np.ndarray`` of shape ``(N, 6)`` with rows
        ``[x1, y1, x2, y2, score, label]``. Shape is ``(0, 6)`` when nothing
        survives filtering.
    """
    with torch.no_grad():
        results = yolo_model(image, verbose=False, imgsz=imgsz, conf=conf)
    det = results[0].boxes
    cls = det.cls.int().cpu()
    indices = torch.where(cls == cls_id)[0]
    if len(indices) == 0:
        return np.empty((0, 6))  # empty, but still 2-D with 6 columns

    labels = det.cls[indices]
    # nms requires boxes and scores to share a dtype. Cast both to float32 so
    # that a half-precision model does not trip the check.
    boxes = det.xyxy[indices].float()
    scores = det.conf[indices].float()

    keep = torchvision.ops.nms(boxes, scores, iou_threshold=nms_iou)
    boxes, scores, labels = boxes[keep], scores[keep], labels[keep]

    # Area filter: drop boxes that are too small to be useful.
    areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
    area_mask = areas >= min_area
    boxes, scores, labels = boxes[area_mask], scores[area_mask], labels[area_mask]
    if len(boxes) == 0:
        return np.empty((0, 6))

    # Convert to numpy and stack into [x1, y1, x2, y2, score, label] rows.
    return np.concatenate(
        [boxes.cpu().numpy(), scores.cpu().numpy()[:, None], labels.cpu().numpy()[:, None]],
        axis=1,
    )
if __name__ == "__main__":
    # Alternative inputs used during development:
    # mp4_path = r"C:\Users\Administrator\Videos\yundong\20250226162704517\20250226162704517.mp4"
    # mp4_path = r"F:\data\lanqiu\150_30\150_30.mp4"
    mp4_path = r"E:\data\tiaosheng\0706\5s.mp4"
    # Video ID = file name without its final extension. splitext keeps any
    # inner dots, unlike split(".")[0].
    video_id = os.path.splitext(os.path.basename(mp4_path))[0]
    yolo_path = r"F:\BaiduNetdiskDownload\tiaosheng_new\model\best_new.pt"
    yolo_cls = YOLO_Class(yolo_path)
    yolo_cls.get_bytetrack_bbox(mp4_path, video_id, output_path=f"{video_id}_tracked.mp4", debug=True)