bytetrack漏检补齐

发布于:2025-07-16 ⋅ 阅读:(20) ⋅ 点赞:(0)

bytetrack漏检补齐

1. 人流慢速运动时,跟踪效果较好;偶尔出现漏检,跟踪器也可以自动补齐。

2. 快速运动或频繁遮挡时,跟踪效果可能不佳。

* 改进思路:如果当前帧的检测框数量少于上一帧(疑似漏检),就用一个临时跟踪器先跟踪当前帧、再跟踪上一帧("倒着跟踪")。然后把上一帧中没有在当前帧匹配上的(即丢失的)检测框拷贝到当前帧,保留下来继续参与跟踪。(示例代码中该补齐分支默认处于关闭状态。)

不过,该方法有时效果并不理想。

from collections import defaultdict
import cv2
import numpy as np
import torchvision
from ultralytics import YOLO
import pickle
import torch
from torchvision.ops import box_iou
from log import logger
import time
import os
from addict import Dict
from track.byte_tracker import BYTETracker
import math

def get_color(idx):
    """Return a deterministic 3-channel color tuple for a track id.

    The id is first scaled by 5; each channel is then a different multiple
    of that value taken modulo 255. The intent is for different ids to be
    drawn in different colors. The tuple is passed to cv2 drawing calls.
    """
    base = 5 * idx
    return tuple((mult * base) % 255 for mult in (37, 17, 29))



class YOLO_Class():
    """Run a YOLO person detector plus a ByteTrack tracker over a video.

    Detections come from ``detect_image_yolo``. Tracking parameters live in
    ``self.par_args`` and are passed to ``BYTETracker.update`` on every frame.
    """

    def __init__(self, model_path, device="cuda:0"):
        """Load the YOLO weights and create the tracker.

        Args:
            model_path: Path to the YOLO weights file.
            device: Stored on the instance for reference only; it is not
                currently forwarded to the model or the tracker.
        """
        self.model = YOLO(model_path)  # YOLO-12 detector; tracking is done by BYTETracker
        self.device = device  # NOTE(review): not used for inference yet
        self.par_args = Dict(
            {"track_thresh": 0.5, "track_buffer": 30, "match_thresh": 0.9, "min_box_area": 10, "mot20": False})
        # frame_rate is fixed at 20 here, independent of the video's real FPS.
        self.tracker = BYTETracker(self.par_args, frame_rate=20)

    def yolo_byte_track(self, detect_bboxes, frame):
        """Feed one frame's detections to the tracker and draw the tracks.

        Args:
            detect_bboxes: Detections as returned by ``detect_image_yolo``
                (rows of ``x1, y1, x2, y2, score, label``).
            frame: Image to draw on; it is modified in place.

        Returns:
            The same ``frame`` object, with a box and a ``"<id> <score>"``
            label drawn for every track that has a positive width and height.
        """
        title_color = (0, 255, 255)
        # NOTE(review): with zero detections the tracker is not updated at all,
        # so lost tracks do not age on empty frames -- confirm this is intended.
        if len(detect_bboxes) == 0:
            return frame

        # Crowded scenes (> 4 detections) use a longer buffer and a looser
        # match threshold; these args are passed to update() on each call.
        if len(detect_bboxes) > 4:
            self.par_args.track_buffer = 60
            self.par_args.match_thresh = 1.6
        else:
            self.par_args.track_buffer = 30
            self.par_args.match_thresh = 0.9

        img_hw = (frame.shape[0], frame.shape[1])
        online_targets = self.tracker.update(np.array(detect_bboxes), list(img_hw), img_hw, self.par_args)

        for t in online_targets:
            x1, y1, w, h = t.tlwh
            if not (w > 0 and h > 0):
                continue  # skip degenerate boxes
            box_color = get_color(t.track_id)
            intbox = tuple(map(int, (x1, y1, x1 + w, y1 + h)))
            cv2.rectangle(frame, intbox[0:2], intbox[2:4], color=box_color, thickness=2)
            cv2.putText(frame, f'{t.track_id} {t.score:.2f} ', (intbox[0], intbox[1] - 5), cv2.FONT_HERSHEY_PLAIN,
                        1.8, title_color, thickness=2)
        return frame

    @staticmethod
    def _prepare_debug_dir(video_id, debug):
        """Create ``yolov12/debug/<video_id>`` when debug is on; return its path or None."""
        if not debug:
            return None
        debug_dir = f"yolov12/debug/{video_id}"
        os.makedirs(debug_dir, exist_ok=True)  # make sure the debug directory exists
        return debug_dir

    @staticmethod
    def _fit_for_display(frame):
        """Return a downscaled copy of ``frame`` for imshow if it exceeds 1000*1300 pixels, else ``frame``."""
        area = np.prod(frame.shape[:2])
        if area > 1000 * 1300:
            x_scale = np.sqrt(1000 * 1200 / area)
            return cv2.resize(frame, None, fx=x_scale, fy=x_scale, interpolation=cv2.INTER_AREA)
        return frame

    def _backfill_from_previous(self, results, last_box, frame_id, img_hw):
        """Append previous-frame boxes that a fresh tracker fails to match in the current frame.

        A throw-away tracker is updated first with the current detections and
        then with the previous frame's detections. Tracks from the second update
        whose id did not appear in the first are copied into ``results`` with
        label 0.
        """
        tracker2 = BYTETracker(self.par_args, frame_rate=3)
        track_now = tracker2.update(results, img_hw, img_hw, self.par_args)
        track_last = tracker2.update(last_box, img_hw, img_hw, self.par_args)

        now_ids = {t.track_id for t in track_now}
        extra_targets = [t for t in track_last if t.track_id not in now_ids]

        for t in extra_targets:
            bx, by, bw, bh = t.tlwh
            print('add box', frame_id, bx, by, bw, bh)
            box_lost = np.asarray([bx, by, bx + bw, by + bh, t.score, 0])
            results = np.vstack([results, box_lost])
        return results

    def get_bytetrack_bbox(self, video_path, video_id, output_path="", debug: bool = False,
                           backfill: bool = False):
        """Detect and track people frame by frame; show and save the annotated video.

        Each frame is shown in a window. ``cv2.waitKey(0)`` blocks until a key
        is pressed, and Esc stops processing. Frames are written at full
        resolution, matching the size the VideoWriter was opened with; only the
        on-screen copy is downscaled.

        Args:
            video_path: Input video path.
            video_id: Used to name the debug directory.
            output_path: Output video path (mp4v fourcc).
            debug: When True, create ``yolov12/debug/<video_id>``.
            backfill: When True and the current frame has fewer detections
                than the previous one, copy unmatched previous-frame boxes into
                the current detections. Off by default.

        Raises:
            RuntimeError: If the input video cannot be opened or the
                VideoWriter fails to initialize.
        """
        self._prepare_debug_dir(video_id, debug)

        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            raise RuntimeError(f"无法打开视频: {video_path}")
        out = None
        try:
            fps = cap.get(cv2.CAP_PROP_FPS) or 30  # some files report no FPS; use a default
            w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            logger.info(f"视频总帧数: {total_frames}, fps: {fps}, 宽: {w}, 高: {h}")

            fourcc = cv2.VideoWriter_fourcc(*"mp4v")
            out = cv2.VideoWriter(output_path, fourcc, fps, (w, h))
            if not out.isOpened():
                raise RuntimeError("VideoWriter 初始化失败,请检查编码器 fourcc 或路径。")

            frame_id = 0
            last_box = []
            while cap.isOpened():
                ok, frame = cap.read()
                if not ok:
                    break
                t0 = time.time()
                results = detect_image_yolo(self.model, frame)
                pic_h, pic_w = frame.shape[:2]

                # Fewer boxes than the previous frame is treated as a possible miss.
                if backfill and len(last_box) > len(results):
                    results = self._backfill_from_previous(results, last_box, frame_id, (pic_h, pic_w))
                last_box = results

                t1 = time.time()
                frame = self.yolo_byte_track(results, frame)
                print(f"{frame_id} det_track time {time.time() - t0:.3f}s track_time {time.time() - t1:.3f}s")

                cv2.imshow("YOLO Track", self._fit_for_display(frame))
                if cv2.waitKey(0) & 0xFF == 27:  # Esc to quit
                    break
                # Write the full-size frame so it matches the writer's (w, h).
                out.write(frame)
                frame_id += 1
        finally:
            cap.release()
            if out is not None:
                out.release()


def detect_image_yolo(yolo_model, image, imgsz=640, conf=0.4, min_area=60*40, max_len=0,
                      class_id=0, nms_iou=0.5):
    """Run the detector on one image and return detections in ByteTrack format.

    Args:
        yolo_model: Callable detector, e.g. an ultralytics ``YOLO``. Its result
            list's first item must expose ``boxes.cls``, ``boxes.xyxy`` and
            ``boxes.conf``.
        image: Image passed unchanged to ``yolo_model``.
        imgsz: Inference size forwarded to the model.
        conf: Confidence threshold forwarded to the model.
        min_area: Boxes whose xyxy area is below this value are dropped.
        max_len: Unused; kept for backward compatibility.
        class_id: Class index to keep (default 0, used as "person" here).
        nms_iou: IoU threshold for the extra NMS pass over the kept boxes.

    Returns:
        ``np.ndarray`` of shape ``[N, 6]`` with columns
        ``x1, y1, x2, y2, score, label``. ``N`` may be 0.
    """
    with torch.no_grad():
        results = yolo_model(image, verbose=False, imgsz=imgsz, conf=conf)

    det = results[0].boxes
    cls = det.cls.int().cpu()
    indices = torch.where(cls == class_id)[0]  # keep only the requested class
    if len(indices) == 0:
        return np.empty((0, 6))  # empty, but with the expected column count

    labels = det.cls[indices]
    # torchvision.ops.nms requires boxes and scores to share a dtype.
    boxes = det.xyxy[indices].float()
    scores = det.conf[indices].float()

    # Extra NMS pass over the kept boxes.
    keep = torchvision.ops.nms(boxes, scores, iou_threshold=nms_iou)
    boxes, scores, labels = boxes[keep], scores[keep], labels[keep]

    # Area filter: drop boxes smaller than min_area.
    areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
    area_mask = areas >= min_area
    boxes, scores, labels = boxes[area_mask], scores[area_mask], labels[area_mask]
    if len(boxes) == 0:
        return np.empty((0, 6))

    # Convert to numpy and stack into [N, 6] rows for ByteTrack.
    boxes = boxes.cpu().numpy()
    scores = scores.cpu().numpy()
    labels = labels.cpu().numpy()
    return np.concatenate([boxes, scores[:, None], labels[:, None]], axis=1)


if __name__ == "__main__":
    # Earlier test videos, kept for quick switching (only the last assignment was ever used):
    # mp4_path = r"C:\Users\Administrator\Videos\yundong\20250226162704517\20250226162704517.mp4"
    # mp4_path = r"F:\data\lanqiu\150_30\150_30.mp4"
    mp4_path = r"E:\data\tiaosheng\0706\5s.mp4"
    # Video id = file name without extension; splitext keeps dots inside the name ("a.b.mp4" -> "a.b").
    video_id = os.path.splitext(os.path.basename(mp4_path))[0]
    yolo_path = r"F:\BaiduNetdiskDownload\tiaosheng_new\model\best_new.pt"
    yolo_cls = YOLO_Class(yolo_path)
    yolo_cls.get_bytetrack_bbox(mp4_path, video_id, output_path=f"{video_id}_tracked.mp4", debug=True)


网站公告

今日签到

点亮在社区的每一天
去签到