Video Keyframe Extraction Algorithms

Published: 2025-08-13

Ready-to-run code:
https://colab.research.google.com/drive/1iXgzIB8k-_ZpgCiGn-r9WgU7mvdRrKVB?usp=sharing
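
All three scripts below share the same dependencies; assuming a standard pip setup, "pip install opencv-python numpy scipy matplotlib" covers them (scipy and matplotlib are only needed by method 1).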

1. Frame-to-frame difference with local maxima (extracts fewer frames)

# -*- coding: utf-8 -*-
import cv2
import operator
import numpy as np
import matplotlib.pyplot as plt
from scipy.signal import argrelextrema
import os

def smooth(x, window_len=13, window='hanning'):
    """Smooth a 1-D signal by convolving it with a window function.

    The signal is mirrored at both ends so the window does not distort
    the boundaries; the padding is trimmed off before returning.
    """
    s = np.r_[2 * x[0] - x[window_len:1:-1],
              x, 2 * x[-1] - x[-1:-window_len:-1]]

    if window == 'flat':  # moving average
        w = np.ones(window_len, 'd')
    else:
        w = getattr(np, window)(window_len)  # e.g. np.hanning, np.hamming
    y = np.convolve(w / w.sum(), s, mode='same')
    return y[window_len - 1:-window_len + 1]


class Frame:
    """A frame index paired with its difference from the previous frame."""

    def __init__(self, id, diff):
        self.id = id
        self.diff = diff

    def __lt__(self, other):
        return self.id < other.id

    def __gt__(self, other):
        return other.__lt__(self)

    def __eq__(self, other):
        return self.id == other.id

    def __ne__(self, other):
        return not self.__eq__(other)


def rel_change(a, b):
    # Relative change between consecutive diff values
    return (b - a) / max(a, b)


if __name__ == "__main__":
    # Use a fixed threshold on the relative change between consecutive diffs
    USE_THRESH = False
    # fixed threshold value
    THRESH = 0.6
    # Keep the N frames with the largest diffs
    USE_TOP_ORDER = False
    # Take local maxima of the smoothed diff curve
    USE_LOCAL_MAXIMA = True
    # Number of top sorted frames
    NUM_TOP_FRAMES = 50

    # Iterate over every MP4 file in the current directory
    for filename in os.listdir("."):
        if filename.endswith(".mp4"):
            videopath = filename  # MP4 file in the current directory
            name = os.path.splitext(filename)[0]  # file name without extension
            out_dir = f"./extract_result/{name}/"  # directory for the extracted keyframes
            os.makedirs(out_dir, exist_ok=True)
            len_window = 50  # smoothing window size

            print("Target video: " + videopath)
            print("Frame save directory: " + out_dir)
            # load video and compute diff between frames
            cap = cv2.VideoCapture(str(videopath))
            curr_frame = None
            prev_frame = None
            frame_diffs = []
            frames = []
            success, frame = cap.read()
            i = 0
            while success:
                # Compare consecutive frames in the LUV color space
                luv = cv2.cvtColor(frame, cv2.COLOR_BGR2LUV)
                curr_frame = luv
                if curr_frame is not None and prev_frame is not None:
                    diff = cv2.absdiff(curr_frame, prev_frame)
                    diff_sum = np.sum(diff)
                    diff_sum_mean = diff_sum / (diff.shape[0] * diff.shape[1])
                    frame_diffs.append(diff_sum_mean)
                    frames.append(Frame(i, diff_sum_mean))
                prev_frame = curr_frame
                i = i + 1
                success, frame = cap.read()
            cap.release()

            # compute keyframe
            keyframe_id_set = set()
            if USE_TOP_ORDER:
                # sort the list in descending order
                frames.sort(key=operator.attrgetter("diff"), reverse=True)
                for keyframe in frames[:NUM_TOP_FRAMES]:
                    keyframe_id_set.add(keyframe.id)
            if USE_THRESH:
                print("Using Threshold")
                for i in range(1, len(frames)):
                    # np.float was removed from NumPy; use the builtin float
                    if rel_change(float(frames[i - 1].diff), float(frames[i].diff)) >= THRESH:
                        keyframe_id_set.add(frames[i].id)
            if USE_LOCAL_MAXIMA:
                print("Using Local Maxima")
                diff_array = np.array(frame_diffs)
                sm_diff_array = smooth(diff_array, len_window)
                # Indices of local maxima on the smoothed diff curve
                frame_indexes = np.asarray(argrelextrema(sm_diff_array, np.greater))[0]
                for i in frame_indexes:
                    keyframe_id_set.add(frames[i - 1].id)

                # Plot the smoothed differences
                plt.figure(figsize=(40, 20))
                plt.gca().xaxis.set_major_locator(plt.MaxNLocator(100))  # number of x-axis ticks
                plt.gca().yaxis.set_major_locator(plt.MaxNLocator(10))   # number of y-axis ticks
                plt.stem(sm_diff_array)
                plt.savefig(out_dir + 'plot.png')

            # save all keyframes as images
            cap = cv2.VideoCapture(str(videopath))
            success, frame = cap.read()
            idx = 0
            while success:
                if idx in keyframe_id_set:
                    img_name = "keyframe_" + str(idx) + ".jpg"
                    cv2.imwrite(out_dir + img_name, frame)
                    keyframe_id_set.remove(idx)
                idx = idx + 1
                success, frame = cap.read()
            cap.release()
            print(f"Keyframes saved to: {out_dir}")

2. Optical-flow-based method (the result is a JSON file of frame info)

Note: despite the name, the script below scores motion with blurred frame differencing (cv2.absdiff plus countNonZero) rather than true optical flow; a dense-flow sketch follows the script.

import cv2
import json
import os
import numpy as np

def getInfo(sourcePath):
    cap = cv2.VideoCapture(sourcePath)
    info = {
        "framecount": int(cap.get(cv2.CAP_PROP_FRAME_COUNT)),
        "fps": cap.get(cv2.CAP_PROP_FPS),
        "width": int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)),
        "height": int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)),
        "codec": int(cap.get(cv2.CAP_PROP_FOURCC))
    }
    cap.release()
    return info

def scale(img, xScale, yScale):
    return cv2.resize(img, None, fx=xScale, fy=yScale, interpolation=cv2.INTER_AREA)

def resize(img, width, height):
    return cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA)

def extract_cols(image, numCols):
    Z = image.reshape((-1, 3)).astype(np.float32)
    criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 20, 1.0)
    _, labels, centers = cv2.kmeans(Z, numCols, None, criteria, 10, cv2.KMEANS_RANDOM_CENTERS)

    clusterCounts = [int(np.sum(labels == i)) for i in range(numCols)]
    rgbCenters = [center.tolist()[::-1] for center in centers]

    return [{"count": count, "col": col} for count, col in zip(clusterCounts, rgbCenters)]

def calculateFrameStats(sourcePath, after_frame=0):
    cap = cv2.VideoCapture(sourcePath)
    data = {"frame_info": []}
    lastFrame = None

    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break

        frame_number = int(cap.get(cv2.CAP_PROP_POS_FRAMES) - 1)
        # Downscale and blur so the diff responds to motion, not sensor noise
        gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
        gray = scale(gray, 0.25, 0.25)
        gray = cv2.GaussianBlur(gray, (9, 9), 0.0)

        if frame_number >= after_frame and lastFrame is not None:
            diff = cv2.absdiff(gray, lastFrame)
            diffMag = int(cv2.countNonZero(diff))  # plain Python int for JSON serialization
            data["frame_info"].append({"frame_number": frame_number, "diff_count": diffMag})

        lastFrame = gray

    cap.release()

    diff_counts = [fi["diff_count"] for fi in data["frame_info"]]
    if diff_counts:
        data["stats"] = {
            "num": int(len(diff_counts)),
            "min": int(np.min(diff_counts)),
            "max": int(np.max(diff_counts)),
            "mean": float(np.mean(diff_counts)),
            "median": float(np.median(diff_counts)),
            "sd": float(np.std(diff_counts))
        }
    return data

def detectScenes(sourcePath, destPath, data):
    if "stats" not in data:  # fewer than two frames were read; nothing to detect
        return data
    # Flag frames whose diff count sits roughly two standard deviations above the mean
    diff_threshold = data["stats"]["sd"] * 2.05 + data["stats"]["mean"]

    cap = cv2.VideoCapture(sourcePath)
    os.makedirs(destPath, exist_ok=True)

    for index, fi in enumerate(data["frame_info"]):
        if fi["diff_count"] < diff_threshold:
            continue

        # Seek to the keyframe and read it
        cap.set(cv2.CAP_PROP_POS_FRAMES, fi["frame_number"])
        ret, frame = cap.read()
        if not ret:
            continue

        # Record the dominant colors so the keyframe metadata JSON below is not empty
        small = resize(frame, 500, int(500 * frame.shape[0] / frame.shape[1]))
        fi["dominant_cols"] = extract_cols(small, 5)

        # Save the keyframe image to the destination folder
        frame_filename = os.path.join(destPath, f"key_frame_{fi['frame_number']}.jpg")
        cv2.imwrite(frame_filename, frame)

    cap.release()
    return data

# Iterate over every MP4 file in the current directory
for filename in os.listdir("."):
    if filename.endswith(".mp4"):
        source = filename
        name = os.path.splitext(filename)[0]  # file name without extension
        dest = name  # destination folder named after the video
        after_frame = 0  # first frame to consider

        print(f"Processing video: {source}")
        info = getInfo(source)
        print("Video info: ", info)

        # Compute frame-difference data and detect scene changes
        data = calculateFrameStats(source, after_frame)
        data = detectScenes(source, dest, data)

        # Save the metadata
        data_fp = os.path.join(dest, f"{name}-meta.json")
        with open(data_fp, 'w') as f:
            json.dump(data, f, indent=4)

        keyframe_info_fp = os.path.join(dest, f"{name}-keyframe-meta.json")
        keyframeInfo = [frame_info for frame_info in data["frame_info"] if "dominant_cols" in frame_info]
        with open(keyframe_info_fp, 'w') as f:
            json.dump(keyframeInfo, f, indent=4)

        print(f"Keyframe data and images saved to: {dest}")

3. Color-histogram clustering (extracts more frames)

import cv2
import numpy as np
import os

# Iterate over every MP4 file in the current directory
for filename in os.listdir("."):
    if filename.endswith(".mp4"):
        video_path = filename
        name = os.path.splitext(filename)[0]  # file name without extension
        output_folder = f'key_frames/{name}'  # directory for the extracted keyframes
        os.makedirs(output_folder, exist_ok=True)

        print(f"Processing video: {video_path}")
        print(f"Keyframe save path: {output_folder}")

        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            raise ValueError("Cannot open the video file!")

        # Total number of frames in the video
        num = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

        cluster = np.zeros(num)  # cluster id assigned to each frame
        cluster_count = np.zeros(num)  # number of frames in each cluster
        count = 0  # number of clusters so far

        threshold = 0.91  # similarity needed to join an existing cluster
        centrodR = np.zeros((num, 256))  # R-channel histogram of each cluster centroid
        centrodG = np.zeros((num, 256))  # G-channel histogram of each cluster centroid
        centrodB = np.zeros((num, 256))  # B-channel histogram of each cluster centroid

        # Read the first frame to seed the first cluster
        ret, frame = cap.read()
        if not ret:
            raise ValueError("Cannot read the first frame!")

        count += 1
        # OpenCV frames are BGR, so channel 2 is R, channel 1 is G, channel 0 is B
        preCountR = cv2.calcHist([frame], [2], None, [256], [0, 256]).flatten()
        preCountG = cv2.calcHist([frame], [1], None, [256], [0, 256]).flatten()
        preCountB = cv2.calcHist([frame], [0], None, [256], [0, 256]).flatten()

        cluster[0] = 1
        cluster_count[0] += 1
        centrodR[0] = preCountR
        centrodG[0] = preCountG
        centrodB[0] = preCountB

        # visit marks the oldest cluster still compared against; it advances with
        # count, so in effect each frame is only matched against the newest cluster
        visit = 1

        # Walk through the remaining frames
        for k in range(1, num):
            ret, frame = cap.read()
            if not ret:
                break

            tmpCountR = cv2.calcHist([frame], [2], None, [256], [0, 256]).flatten()
            tmpCountG = cv2.calcHist([frame], [1], None, [256], [0, 256]).flatten()
            tmpCountB = cv2.calcHist([frame], [0], None, [256], [0, 256]).flatten()

            clusterGroupId = 1
            maxSimilar = 0

            # Histogram-intersection similarity, weighted by the approximate luma
            # contribution of each channel (R 0.30, G 0.59, B 0.11)
            for clusterCountI in range(visit, count + 1):
                sR = np.sum(np.minimum(centrodR[clusterCountI - 1], tmpCountR))
                sG = np.sum(np.minimum(centrodG[clusterCountI - 1], tmpCountG))
                sB = np.sum(np.minimum(centrodB[clusterCountI - 1], tmpCountB))

                dR = sR / np.sum(tmpCountR)
                dG = sG / np.sum(tmpCountG)
                dB = sB / np.sum(tmpCountB)
                d = 0.30 * dR + 0.59 * dG + 0.11 * dB

                if d > maxSimilar:
                    clusterGroupId = clusterCountI
                    maxSimilar = d

            # Join the best-matching cluster if similar enough; otherwise start a new one
            if maxSimilar > threshold:
                centrodR[clusterGroupId - 1] = (centrodR[clusterGroupId - 1] * cluster_count[clusterGroupId - 1] + tmpCountR) / (cluster_count[clusterGroupId - 1] + 1)
                centrodG[clusterGroupId - 1] = (centrodG[clusterGroupId - 1] * cluster_count[clusterGroupId - 1] + tmpCountG) / (cluster_count[clusterGroupId - 1] + 1)
                centrodB[clusterGroupId - 1] = (centrodB[clusterGroupId - 1] * cluster_count[clusterGroupId - 1] + tmpCountB) / (cluster_count[clusterGroupId - 1] + 1)
                cluster_count[clusterGroupId - 1] += 1
                cluster[k] = clusterGroupId
            else:
                count += 1
                visit += 1
                cluster_count[count - 1] += 1
                centrodR[count - 1] = tmpCountR
                centrodG[count - 1] = tmpCountG
                centrodB[count - 1] = tmpCountB
                cluster[k] = count

        cap.release()

        # Second pass: for each cluster, keep the frame most similar to its centroid
        max_similarity = np.zeros(count)
        frame_indices = np.zeros(count, dtype=int)

        cap = cv2.VideoCapture(video_path)
        frame_number = 0

        while True:
            ret, frame = cap.read()
            if not ret:
                break

            tmpCountR = cv2.calcHist([frame], [2], None, [256], [0, 256]).flatten()
            tmpCountG = cv2.calcHist([frame], [1], None, [256], [0, 256]).flatten()
            tmpCountB = cv2.calcHist([frame], [0], None, [256], [0, 256]).flatten()

            cid = int(cluster[frame_number]) - 1  # cluster this frame belongs to
            sR = np.sum(np.minimum(centrodR[cid], tmpCountR))
            sG = np.sum(np.minimum(centrodG[cid], tmpCountG))
            sB = np.sum(np.minimum(centrodB[cid], tmpCountB))

            dR = sR / np.sum(tmpCountR)
            dG = sG / np.sum(tmpCountG)
            dB = sB / np.sum(tmpCountB)
            d = 0.30 * dR + 0.59 * dG + 0.11 * dB

            if d > max_similarity[cid]:
                max_similarity[cid] = d
                frame_indices[cid] = frame_number

            frame_number += 1

        cap.release()

        # Save the keyframes to the output folder
        cap = cv2.VideoCapture(video_path)
        for idx in frame_indices:
            cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
            ret, img = cap.read()
            if ret:
                frame_filename = os.path.join(output_folder, f'key_frame_{int(idx)}.jpg')
                cv2.imwrite(frame_filename, img)
        cap.release()

        print(f"关键帧已保存到:{output_folder}")
