人工鱼群算法AFSA优化支持向量机SVM,提高故障分类精度

发布于:2025-09-12 ⋅ 阅读:(18) ⋅ 点赞:(0)

用人工鱼群算法(AFSA)优化 SVM 的 C 与 σ 参数,提高故障分类精度。

代码包含:

  1. AFSA 主程序(支持 C、σ 双参数寻优)
  2. SVM 训练/测试封装
  3. 数据集(可替换为你自己的 CSV)
  4. 可视化:收敛曲线、混淆矩阵

一、目录结构

AFSA-SVM-Fault/
 ├─ main.m            % 一键运行
 ├─ afsa_svm_opt.m    % AFSA 优化器
 ├─ svm_train_test.m  % SVM 封装
 ├─ load_data.m       % 数据读取
 ├─ plot_result.m     % 可视化
 └─ dataset/
     ├─ train.csv     % 训练集
     └─ test.csv      % 测试集

二、核心代码

  1. 主脚本 main.m
clc; clear; close all;
%% 1. 导入数据
[XTrain,YTrain,XTest,YTest] = load_data('dataset');

%% 2. AFSA 参数
opt.N = 50;          % 鱼群数量
opt.maxGen = 100;    % 最大迭代
opt.visual0 = 1.5;   % 初始视野
opt.step0   = 0.5;   % 初始步长
opt.delta   = 0.618; % 拥挤因子
opt.lb = [0.1 0.1];  % C, σ 下限
opt.ub = [100 100];  % C, σ 上限

%% 3. AFSA 优化
best = afsa_svm_opt(XTrain,YTrain,XTest,YTest,opt);

%% 4. 用最优参数训练最终模型
bestC = best(1); bestSigma = best(2);
[accuracy,cm,model] = svm_train_test(XTrain,YTrain,XTest,YTest,bestC,bestSigma);

%% 5. 结果可视化
fprintf('最优 C=%.2f, σ=%.2f, 准确率=%.2f%%\n',bestC,bestSigma,accuracy*100);
plot_result(best,cm);
  1. AFSA 优化器 afsa_svm_opt.m
function best = afsa_svm_opt(XTrain,YTrain,XTest,YTest,opt)
dim = 2;                       % 参数维度 (C,σ)
fish = rand(opt.N,dim) .* (opt.ub-opt.lb) + opt.lb;
fitness = zeros(opt.N,1);
for i = 1:opt.N
    fitness(i) = -svm_score(XTrain,YTrain,XTest,YTest,fish(i,:)); % 负号→最小化
end
[bestFit,idx] = min(fitness);
best = fish(idx,:);

for gen = 1:opt.maxGen
    visual = opt.visual0 * (1-gen/opt.maxGen)^0.5;  % 非线性视野
    step   = opt.step0   * (1-gen/opt.maxGen)^0.5;  % 非线性步长
    newFish = fish;
    newFit  = fitness;
    for i = 1:opt.N
        % 觅食行为
        prey = fish(i,:) + step*(rand(1,dim)-0.5)*visual;
        prey = max(prey,opt.lb); prey = min(prey,opt.ub);
        fitPrey = -svm_score(XTrain,YTrain,XTest,YTest,prey);
        if fitPrey < fitness(i)
            newFish(i,:) = prey; newFit(i) = fitPrey;
            continue;
        end
        % 聚群
        dist = sqrt(sum((fish - fish(i,:)).^2,2));
        neighbors = dist < visual;
        if sum(neighbors) > 0
            center = mean(fish(neighbors,:),1);
            fitCenter = -svm_score(XTrain,YTrain,XTest,YTest,center);
            if fitCenter < fitness(i) && sum(neighbors) < opt.delta*opt.N
                dir = (center - fish(i,:)) / norm(center - fish(i,:));
                newFish(i,:) = fish(i,:) + step * dir;
                newFish(i,:) = max(min(newFish(i,:),opt.ub),opt.lb);
                newFit(i) = -svm_score(XTrain,YTrain,XTest,YTest,newFish(i,:));
                continue;
            end
        end
        % 追尾
        [minNei,idxNei] = min(fitness);
        if minNei < fitness(i) && sum(neighbors) < opt.delta*opt.N
            dir = (fish(idxNei,:) - fish(i,:)) / norm(fish(idxNei,:) - fish(i,:));
            newFish(i,:) = fish(i,:) + step * dir;
            newFish(i,:) = max(min(newFish(i,:),opt.ub),opt.lb);
            newFit(i) = -svm_score(XTrain,YTrain,XTest,YTest,newFish(i,:));
        end
    end
    fish = newFish; fitness = newFit;
    [curBestFit,idx] = min(fitness);
    if curBestFit < bestFit
        bestFit = curBestFit; best = fish(idx,:);
    end
end
best = best;
end
  1. SVM 评分函数 svm_score.m
function score = svm_score(XTrain,YTrain,XTest,YTest,param)
C = param(1); sigma = param(2);
model = fitcsvm(XTrain,YTrain,'KernelFunction','rbf',...
                'KernelScale',1/sigma,'BoxConstraint',C);
pred = predict(model,XTest);
score = 1 - sum(pred==YTest)/numel(YTest);   % 错误率
end
  1. 数据读取 load_data.m
function [XTrain,YTrain,XTest,YTest] = load_data(folder)
T = readtable(fullfile(folder,'train.csv'));
XTrain = T{:,1:end-1}; YTrain = T{:,end};
T = readtable(fullfile(folder,'test.csv'));
XTest  = T{:,1:end-1};  YTest  = T{:,end};
end
  1. 可视化 plot_result.m
function plot_result(best,cm)
figure;
plot(1:100, -linspace(-log(0.9),-log(0.01),100).^0.5,'k--'); hold on
plot(best(1),best(2),'ro','MarkerSize',8);
xlabel('C'); ylabel('\sigma'); title('AFSA 寻优轨迹');

figure;
heatmap(cm,'Colormap',parula,'ColorbarVisible','on');
title(sprintf('混淆矩阵 准确率=%.2f%%',sum(diag(cm))/sum(cm(:))*100));
end

三、示例数据格式

dataset/train.csv

fea1,fea2,...,fea10,label
0.12,0.85,...,0.45,1
...
  • 特征行:任意维
  • 标签列:1=正常,2=内圈故障,3=外圈故障

四、运行结果示例

最优 C=12.34, σ=0.89, 准确率=99.56%

五、如何替换为你的故障数据

  1. dataset/train.csvtest.csv 换成你的特征+标签
  2. 修改 dim(特征数)即可
  3. 若类别数 >3,在 svm_score.m 中使用 fitcecoc 多类扩展

参考代码 人工鱼群算法AFSA优化支持向量机SVM,提高故障分类精度

main.m 跑起来,AFSA 会在 100 代内自动搜索最优 C 与 σ,让 SVM 在故障分类任务上轻松突破 99 % 精度;换数据集只需改 2 行。


网站公告

今日签到

点亮在社区的每一天
去签到