FPGA多通道卷积加速器:从零构建手写识别的硬件引擎

发布于:2025-06-16 ⋅ 阅读:(15) ⋅ 点赞:(0)

我最近在做一个很有意思的项目:在FPGA上部署CNN,实现手写图片的识别。本篇文章是这个项目的第二步,具体代码已发布在GitHub上。

模块介绍

卷积神经网络(CNN)通常由卷积层、池化层、激活层和全连接层等结构组成。本篇实现的是CNN卷积层中的卷积运算模块。

卷积运算的过程如下图所示:

img

在权重参数已经确定的情况下,可以把整个卷积过程看作“数据滑窗”和“卷积运算”两个步骤的反复执行。前文中我们已经实现了负责数据滑窗的window模块,本文实现卷积运算模块。

运算过程如下:
$$\begin{bmatrix}1&2&3\\4&5&6\\7&8&9\end{bmatrix} \ast \begin{bmatrix}1&0&1\\0&1&0\\1&1&2\end{bmatrix} = 1\cdot1+2\cdot0+3\cdot1+4\cdot0+5\cdot1+6\cdot0+7\cdot1+8\cdot1+9\cdot2 = 42$$

代码

  1. 模块可配置参数、输入和输出定义

为了支持多通道并行处理,模块的输入是将所有输入通道展平后得到的一维数据,包括窗口数据和权重参数。

DATA_WIDTH和WEIGHT_WIDTH分开定义,因为后续工作中会对权重进行定点数量化。

module mult_acc_comb #(
    parameter DATA_WIDTH = 8,     // width of one (unsigned) window element
    parameter KERNEL_SIZE = 3,    // kernel is KERNEL_SIZE x KERNEL_SIZE
    parameter IN_CHANNEL = 3,     // number of input channels summed into one output
    parameter WEIGHT_WIDTH = 8,   // width of one (unsigned) weight; separate from DATA_WIDTH for later fixed-point quantization
    parameter OUTPUT_WIDTH = 20,  // configurable output width; the final sum is saturated to this many bits
    // Accumulator width. A sum of N unsigned products of DATA_WIDTH x WEIGHT_WIDTH
    // bits needs at most DATA_WIDTH + WEIGHT_WIDTH + clog2(N) bits; 4 extra guard
    // bits are kept on top. Using WEIGHT_WIDTH here (rather than assuming it equals
    // DATA_WIDTH) keeps the accumulator wide enough when the two widths differ.
    parameter ACC_WIDTH = DATA_WIDTH + WEIGHT_WIDTH + 4 + $clog2(KERNEL_SIZE*KERNEL_SIZE*IN_CHANNEL)
)(
    // Input data interface. Each bus carries all IN_CHANNEL*KERNEL_SIZE*KERNEL_SIZE
    // elements flattened into one vector. Window elements are unpacked LSB-first,
    // weight elements MSB-first.
    input window_valid,
    input [IN_CHANNEL*KERNEL_SIZE*KERNEL_SIZE*DATA_WIDTH-1:0] multi_channel_window_in,
    input weight_valid,
    input [IN_CHANNEL*KERNEL_SIZE*KERNEL_SIZE*WEIGHT_WIDTH-1:0] multi_channel_weight_in,

    // Output data interface
    output [OUTPUT_WIDTH-1:0] conv_out, // saturated convolution result (OUTPUT_WIDTH bits)
    output conv_valid                   // high when both window_valid and weight_valid are high
);
  2. 定义内部相关信号
// ---------------------------------------------------------------------------
// Internal signals
// ---------------------------------------------------------------------------

// Number of taps in one KERNEL_SIZE x KERNEL_SIZE window of a single channel.
localparam KERNEL_AREA = KERNEL_SIZE * KERNEL_SIZE;
// Total number of weights in one filter (all channels, all taps).
localparam WEIGHTS_PER_FILTER = IN_CHANNEL * KERNEL_AREA;
// Full-precision width of one window x weight product.
localparam PRODUCT_WIDTH = DATA_WIDTH + WEIGHT_WIDTH;

// Unpacked window and weight elements, indexed [channel][tap]; unsigned.
wire [DATA_WIDTH-1:0]    channel_window_data [0:IN_CHANNEL-1][0:KERNEL_AREA-1];
wire [WEIGHT_WIDTH-1:0]  channel_weight_data [0:IN_CHANNEL-1][0:KERNEL_AREA-1];

// Element-wise products, indexed [channel][tap]; unsigned.
wire [PRODUCT_WIDTH-1:0] mult_results        [0:IN_CHANNEL-1][0:KERNEL_AREA-1];

// Per-channel sums, then the final sum across all channels.
wire [ACC_WIDTH-1:0]     channel_sums        [0:IN_CHANNEL-1];
wire [ACC_WIDTH-1:0]     total_sum;

// Generate-loop indices shared by the generate blocks of this module.
genvar ch;     // channel index
genvar i_idx;  // tap index within a channel
genvar k_idx;  // index of the per-channel running-sum chain
genvar c_idx;  // index of the cross-channel running-sum chain
  3. 输入数据解包
// Unpack the flat input buses into [channel][tap] arrays.
// The two buses use opposite element orders:
//   * window : element (ch, i) sits in slot ch*KERNEL_SIZE*KERNEL_SIZE + i counted
//              from the LSB, so element (0, 0) occupies the lowest DATA_WIDTH bits;
//   * weights: both the channel and the tap index are mirrored, so element (0, 0)
//              occupies the highest WEIGHT_WIDTH bits of multi_channel_weight_in.
generate
    for (ch = 0; ch < IN_CHANNEL; ch = ch + 1) begin : unpack_gen
        for (i_idx = 0; i_idx < KERNEL_SIZE*KERNEL_SIZE; i_idx = i_idx + 1) begin : element_gen
            // window data: LSB-first
            assign channel_window_data[ch][i_idx] =
                multi_channel_window_in[(ch*KERNEL_SIZE*KERNEL_SIZE*DATA_WIDTH + i_idx*DATA_WIDTH) +: DATA_WIDTH];
            // weight data: MSB-first (mirrored channel and tap)
            assign channel_weight_data[ch][i_idx] =
                multi_channel_weight_in[(((IN_CHANNEL - 1 - ch)*KERNEL_SIZE*KERNEL_SIZE
                                          + (KERNEL_SIZE*KERNEL_SIZE - 1 - i_idx)) * WEIGHT_WIDTH) +: WEIGHT_WIDTH];
        end
    end
endgenerate

a[b +: c]的含义是:以a的第b位为起点,向高位方向取出c位,即a[b+c-1:b]。例如a[8 +: 8]等价于a[15:8]。

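输入的window和weight的数据结构变化如下。注意两者的解包顺序相反:窗口数据按低位优先解包(第0通道第0个元素位于multi_channel_window_in的最低DATA_WIDTH位),而权重按高位优先解包(第0通道第0个元素位于multi_channel_weight_in的最高WEIGHT_WIDTH位)。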

在这里插入图片描述

  4. 并行卷积运算

所有通道同时进行卷积运算

// ---------------------------------------------------------------------------
// Stage 1 - element-wise products (combinational).
// One multiplier per (channel, tap) pair, so every product is formed in
// parallel. Operands are unsigned; mult_results is DATA_WIDTH+WEIGHT_WIDTH
// bits wide, so no product bits are lost.
// ---------------------------------------------------------------------------
generate
    for (ch = 0; ch < IN_CHANNEL; ch = ch + 1) begin : mult_ch_gen
        for (i_idx = 0; i_idx < KERNEL_SIZE*KERNEL_SIZE; i_idx = i_idx + 1) begin : mult_elem_gen
            assign mult_results[ch][i_idx] =
                channel_window_data[ch][i_idx] * channel_weight_data[ch][i_idx];
        end
    end
endgenerate


// ---------------------------------------------------------------------------
// Stage 2 - per-channel sums (combinational adder logic).
// Every addition below is evaluated at ACC_WIDTH bits, the width of the
// assignment target, so intermediate sums do not wrap at the product width.
// ---------------------------------------------------------------------------
generate
    for (ch = 0; ch < IN_CHANNEL; ch = ch + 1) begin : sum_ch_gen
        if (KERNEL_SIZE == 3) begin : kernel3_sum
            // 3x3 kernel: add the nine products in three groups of three.
            assign channel_sums[ch] =
                  (mult_results[ch][0] + mult_results[ch][1] + mult_results[ch][2])
                + (mult_results[ch][3] + mult_results[ch][4] + mult_results[ch][5])
                + (mult_results[ch][6] + mult_results[ch][7] + mult_results[ch][8]);
        end else begin : general_sum
            // Any other kernel size: a running-sum chain in which
            // partial_sums[k] holds products 0..k of this channel.
            wire [ACC_WIDTH-1:0] partial_sums [0:KERNEL_SIZE*KERNEL_SIZE-1];
            assign partial_sums[0] = mult_results[ch][0];
            for (k_idx = 1; k_idx < KERNEL_SIZE*KERNEL_SIZE; k_idx = k_idx + 1) begin : acc_gen
                assign partial_sums[k_idx] = mult_results[ch][k_idx] + partial_sums[k_idx-1];
            end
            assign channel_sums[ch] = partial_sums[KERNEL_SIZE*KERNEL_SIZE-1];
        end
    end
endgenerate
  5. 跨通道累加并输出

将所有通道的结果相加,经饱和处理后输出。

// Cross-channel accumulation (combinational): total_sum = sum of all channel_sums.
generate
    if (IN_CHANNEL == 3) begin : channel3_sum
        assign total_sum = channel_sums[0] + channel_sums[1] + channel_sums[2];
    end else begin : general_channel_sum
        // Running-sum chain; channel_partial_sums[c] holds channels 0..c.
        wire [ACC_WIDTH-1:0] channel_partial_sums [0:IN_CHANNEL-1];
        assign channel_partial_sums[0] = channel_sums[0];
        for (c_idx = 1; c_idx < IN_CHANNEL; c_idx = c_idx + 1) begin : ch_acc_gen
            assign channel_partial_sums[c_idx] = channel_partial_sums[c_idx-1] + channel_sums[c_idx];
        end
        assign total_sum = channel_partial_sums[IN_CHANNEL-1];
    end
endgenerate

// Output logic (combinational): the result is valid only when both the window
// and the weights are valid; otherwise conv_out is driven to zero.
assign conv_valid = window_valid && weight_valid;
assign conv_out = conv_valid ? saturate(total_sum) : {OUTPUT_WIDTH{1'b0}};

// Unsigned saturation (combinational): returns value unchanged when it fits in
// OUTPUT_WIDTH bits, otherwise the largest OUTPUT_WIDTH-bit value (all ones).
function [OUTPUT_WIDTH-1:0] saturate;
    input [ACC_WIDTH-1:0] value; // unsigned
    // Zero-extended working copy at least as wide as both ACC_WIDTH and
    // OUTPUT_WIDTH. Selecting value[OUTPUT_WIDTH-1:0] directly would read
    // out-of-range (X) bits whenever OUTPUT_WIDTH > ACC_WIDTH. Using this copy
    // also avoids deriving the limit from a 32-bit literal such as (1 << OUTPUT_WIDTH).
    reg [ACC_WIDTH+OUTPUT_WIDTH-1:0] wide;
    begin
        wide = value;
        // Any set bit at or above bit OUTPUT_WIDTH means the sum does not fit.
        if ((wide >> OUTPUT_WIDTH) != 0)
            saturate = {OUTPUT_WIDTH{1'b1}};
        else
            saturate = wide[OUTPUT_WIDTH-1:0];
    end
endfunction

测试

mult_acc_comb_tb.v

为验证其功能正确性,使用多个测试用例进行测试,并对比期望结果与实际结果。

`timescale 1ns / 1ps

// Self-checking testbench for mult_acc_comb.
// Every test drives one value into all window elements and one value into all
// weight elements, so the expected result is N_ELEMS * window * weight,
// saturated to OUTPUT_WIDTH bits. Input buses and expected values are derived
// from the parameters below, so the testbench still works if they are changed.
module mult_acc_comb_tb;

parameter DATA_WIDTH = 8;
parameter KERNEL_SIZE = 3;
parameter IN_CHANNEL = 3;
parameter WEIGHT_WIDTH = 8;
parameter OUTPUT_WIDTH = 20;
// Product width + log2(number of terms) + 4 guard bits
parameter ACC_WIDTH = DATA_WIDTH + WEIGHT_WIDTH + 4 + $clog2(KERNEL_SIZE*KERNEL_SIZE*IN_CHANNEL);

// Number of window x weight products summed into one output (27 with the defaults)
localparam N_ELEMS = IN_CHANNEL * KERNEL_SIZE * KERNEL_SIZE;

reg window_valid;
reg [IN_CHANNEL*KERNEL_SIZE*KERNEL_SIZE*DATA_WIDTH-1:0] multi_channel_window_in;
reg weight_valid;
reg [IN_CHANNEL*KERNEL_SIZE*KERNEL_SIZE*WEIGHT_WIDTH-1:0] multi_channel_weight_in;

wire [OUTPUT_WIDTH-1:0] conv_out;
wire conv_valid;

// Sized all-ones constants (correct for any width, unlike (1 << W) - 1 built
// from a 32-bit literal)
localparam [OUTPUT_WIDTH-1:0] MAX_UNSIGNED_OUT_VAL      = {OUTPUT_WIDTH{1'b1}};
localparam [DATA_WIDTH-1:0]   MAX_ELEMENT_VAL_TB        = {DATA_WIDTH{1'b1}};
localparam [WEIGHT_WIDTH-1:0] MAX_WEIGHT_ELEMENT_VAL_TB = {WEIGHT_WIDTH{1'b1}};

mult_acc_comb #(
    .DATA_WIDTH(DATA_WIDTH),
    .KERNEL_SIZE(KERNEL_SIZE),
    .IN_CHANNEL(IN_CHANNEL),
    .WEIGHT_WIDTH(WEIGHT_WIDTH),
    .OUTPUT_WIDTH(OUTPUT_WIDTH),
    .ACC_WIDTH(ACC_WIDTH)
) dut (
    .window_valid(window_valid),
    .multi_channel_window_in(multi_channel_window_in),
    .weight_valid(weight_valid),
    .multi_channel_weight_in(multi_channel_weight_in),
    .conv_out(conv_out),
    .conv_valid(conv_valid)
);

reg all_tests_passed_flag;
integer test_id_counter;
integer num_errors;

// Reference model: N_ELEMS identical products win_val * wgt_val, summed at full
// precision and saturated to OUTPUT_WIDTH bits. The input declarations truncate
// the arguments to the element widths, matching what drive_inputs puts on the bus.
function [OUTPUT_WIDTH-1:0] expected_conv;
    input [DATA_WIDTH-1:0]   win_val;
    input [WEIGHT_WIDTH-1:0] wgt_val;
    reg [DATA_WIDTH+WEIGHT_WIDTH+31:0] raw; // N_ELEMS fits in 32 bits
    begin
        raw = N_ELEMS * win_val * wgt_val;
        if (raw > MAX_UNSIGNED_OUT_VAL)
            expected_conv = MAX_UNSIGNED_OUT_VAL;
        else
            expected_conv = raw; // fits in OUTPUT_WIDTH bits here
    end
endfunction

// Drive win_val into every window element and wgt_val into every weight element.
task drive_inputs;
    input [DATA_WIDTH-1:0]   win_val;
    input [WEIGHT_WIDTH-1:0] wgt_val;
    integer e;
    begin
        for (e = 0; e < N_ELEMS; e = e + 1) begin
            multi_channel_window_in[e*DATA_WIDTH +: DATA_WIDTH]     = win_val;
            multi_channel_weight_in[e*WEIGHT_WIDTH +: WEIGHT_WIDTH] = wgt_val;
        end
    end
endtask

// Task to check results and display Expected/Actual for all
task check_and_report;
    input [OUTPUT_WIDTH-1:0] expected_out_val;
    input expected_valid_val;
    // Test description is displayed before calling this task
    begin
        test_id_counter = test_id_counter + 1;

        // Always display Expected and Actual
        $display("    Expected: conv_valid=%b, conv_out=%d", expected_valid_val, expected_out_val);
        $display("    Actual:   conv_valid=%b, conv_out=%d", conv_valid, conv_out);

        if (conv_valid === expected_valid_val &&
            ( (expected_valid_val === 1'b0) ? (conv_out === {OUTPUT_WIDTH{1'b0}}) : (conv_out === expected_out_val) ) ) begin
            $display("    Test ID %0d: Status: PASSED", test_id_counter);
        end else begin
            $display("    Test ID %0d: Status: FAILED", test_id_counter);
            all_tests_passed_flag = 1'b0;
            num_errors = num_errors + 1;
        end
        $display("--------------------------------------------------");
    end
endtask

initial begin
    $display("=== Comprehensive UNSIGNED Combinational MultAcc Test (OUTPUT_WIDTH=%0d) ===", OUTPUT_WIDTH);
    all_tests_passed_flag = 1'b1;
    test_id_counter = 0;
    num_errors = 0;

    // Initialize
    window_valid = 0;
    weight_valid = 0;
    multi_channel_window_in = 0;
    multi_channel_weight_in = 0;

    #10;

    // Test 1
    $display("Test Description: Simple Positive Values (1*1, sum %0d)", N_ELEMS);
    drive_inputs(1, 1);
    window_valid = 1;
    weight_valid = 1;
    #1;
    check_and_report(expected_conv(1, 1), 1'b1);

    #10;

    // Test 2 (raw sum 162 with the default parameters)
    $display("Test Description: Positive Values with Saturation (2*3, raw %0d, sat %0d)", N_ELEMS*2*3, expected_conv(2, 3));
    drive_inputs(2, 3);
    #1;
    check_and_report(expected_conv(2, 3), 1'b1);

    #10;

    // Test 3
    $display("Test Description: Invalid Inputs (both valid_n low)");
    window_valid = 0;
    weight_valid = 0;
    #1;
    check_and_report(0, 1'b0);

    #10;

    // Test 4
    $display("Test Description: Zero Window Data, Non-zero Weights");
    window_valid = 1;
    weight_valid = 1;
    drive_inputs(0, 5);
    #1;
    check_and_report(expected_conv(0, 5), 1'b1);

    #10;

    // Test 5
    $display("Test Description: Non-zero Window, Zero Weight Data");
    drive_inputs(5, 0);
    #1;
    check_and_report(expected_conv(5, 0), 1'b1);

    #10;

    // Test 6
    $display("Test Description: All Zero Inputs");
    drive_inputs(0, 0);
    #1;
    check_and_report(expected_conv(0, 0), 1'b1);

    #10;

    // Test 7 (default parameters: 27*25 = 675, well within the 20-bit range)
    $display("Test Description: Large values (no saturation with %0d-bit output)", OUTPUT_WIDTH);
    drive_inputs(5, 5);
    #1;
    check_and_report(expected_conv(5, 5), 1'b1);

    #10;

    // Test 8
    $display("Test Description: Max Val Inputs (Win=%d, Wgt=%d), should saturate to %d", MAX_ELEMENT_VAL_TB, MAX_WEIGHT_ELEMENT_VAL_TB, MAX_UNSIGNED_OUT_VAL);
    drive_inputs(MAX_ELEMENT_VAL_TB, MAX_WEIGHT_ELEMENT_VAL_TB);
    #1;
    // Default parameters: 27 * 255 * 255 = 1,755,675, which exceeds the 20-bit max (1,048,575), so the output saturates
    check_and_report(expected_conv(MAX_ELEMENT_VAL_TB, MAX_WEIGHT_ELEMENT_VAL_TB), 1'b1);

    #10;

    // Test 8.5: OUTPUT_WIDTH range capability (default parameters: 27*10000 = 270000)
    $display("Test Description: Medium values to test %0d-bit range (100*100, sum %0d)", OUTPUT_WIDTH, expected_conv(100, 100));
    drive_inputs(100, 100);
    #1;
    check_and_report(expected_conv(100, 100), 1'b1);

    #10;

    // Test 9: Window valid toggles
    $display("--- Test Sequence 9: Window Valid Toggles (base inputs 1*1, sum %0d) ---", N_ELEMS);
    drive_inputs(1, 1);
    weight_valid = 1;

    $display("  Sub-Test Description: WinValid=1 (Start)");
    window_valid = 1; #1; check_and_report(expected_conv(1, 1), 1'b1);
    $display("  Sub-Test Description: WinValid=0");
    window_valid = 0; #1; check_and_report(0, 1'b0);
    $display("  Sub-Test Description: WinValid=1 (End)");
    window_valid = 1; #1; check_and_report(expected_conv(1, 1), 1'b1);

    #10;

    // Test 10: Weight valid toggles (inputs are still all 1s)
    $display("--- Test Sequence 10: Weight Valid Toggles (base inputs 1*1, sum %0d) ---", N_ELEMS);
    window_valid = 1;

    $display("  Sub-Test Description: WeightValid=1 (Start)");
    weight_valid = 1; #1; check_and_report(expected_conv(1, 1), 1'b1);
    $display("  Sub-Test Description: WeightValid=0");
    weight_valid = 0; #1; check_and_report(0, 1'b0);
    $display("  Sub-Test Description: WeightValid=1 (End)");
    weight_valid = 1; #1; check_and_report(expected_conv(1, 1), 1'b1);

    #10;

    // Final Summary
    $display("==================================================");
    if (all_tests_passed_flag) begin
        $display("FINAL STATUS: SUCCESS! All %0d UNSIGNED Combinational MultAcc tests passed!", test_id_counter);
    end else begin
        $display("FINAL STATUS: FAILED. %0d out of %0d UNSIGNED Combinational MultAcc tests did not pass.", num_errors, test_id_counter);
    end
    $display("==================================================");

    $finish;
end

endmodule

结果

window模块每个周期都会输出数据,因此这里采用组合逻辑实现卷积运算。当输入数据同时有效,即window_valid和weight_valid同时为高时,mult_acc_comb进行运算并将conv_valid拉高,如下图所示。

在这里插入图片描述

输出打印结果:

=== Comprehensive UNSIGNED Combinational MultAcc Test (OUTPUT_WIDTH=20) ===
Test Description: Simple Positive Values (1*1, sum 27)
Expected: conv_valid=1, conv_out= 27
Actual: conv_valid=1, conv_out= 27

Test ID 1: Status: PASSED

Test Description: Positive Values with Saturation (2*3, raw 162, sat 162)
Expected: conv_valid=1, conv_out= 162
Actual: conv_valid=1, conv_out= 162

Test ID 2: Status: PASSED

Test Description: Invalid Inputs (both valid_n low)
Expected: conv_valid=0, conv_out= 0
Actual: conv_valid=0, conv_out= 0

Test ID 3: Status: PASSED

Test Description: Zero Window Data, Non-zero Weights
Expected: conv_valid=1, conv_out= 0
Actual: conv_valid=1, conv_out= 0

Test ID 4: Status: PASSED

Test Description: Non-zero Window, Zero Weight Data
Expected: conv_valid=1, conv_out= 0
Actual: conv_valid=1, conv_out= 0

Test ID 5: Status: PASSED

Test Description: All Zero Inputs
Expected: conv_valid=1, conv_out= 0
Actual: conv_valid=1, conv_out= 0

Test ID 6: Status: PASSED

Test Description: Large values (no saturation with 20-bit output)
Expected: conv_valid=1, conv_out= 675
Actual: conv_valid=1, conv_out= 675

Test ID 7: Status: PASSED

Test Description: Max Val Inputs (Win= 255, Wgt= 255), should saturate to 1048575
Expected: conv_valid=1, conv_out=1048575
Actual: conv_valid=1, conv_out=1048575

Test ID 8: Status: PASSED

Test Description: Medium values to test 20-bit range (100*100, sum 270000)
Expected: conv_valid=1, conv_out= 270000
Actual: conv_valid=1, conv_out= 270000

Test ID 9: Status: PASSED

--- Test Sequence 9: Window Valid Toggles (base inputs 1*1, sum 27) ---
Sub-Test Description: WinValid=1 (Start)
Expected: conv_valid=1, conv_out= 27
Actual: conv_valid=1, conv_out= 27

Test ID 10: Status: PASSED

Sub-Test Description: WinValid=0
Expected: conv_valid=0, conv_out= 0
Actual: conv_valid=0, conv_out= 0

Test ID 11: Status: PASSED

Sub-Test Description: WinValid=1 (End)
Expected: conv_valid=1, conv_out= 27
Actual: conv_valid=1, conv_out= 27

Test ID 12: Status: PASSED

--- Test Sequence 10: Weight Valid Toggles (base inputs 1*1, sum 27) ---
Sub-Test Description: WeightValid=1 (Start)
Expected: conv_valid=1, conv_out= 27
Actual: conv_valid=1, conv_out= 27

Test ID 13: Status: PASSED

Sub-Test Description: WeightValid=0
Expected: conv_valid=0, conv_out= 0
Actual: conv_valid=0, conv_out= 0

Test ID 14: Status: PASSED

Sub-Test Description: WeightValid=1 (End)
Expected: conv_valid=1, conv_out= 27
Actual: conv_valid=1, conv_out= 27

Test ID 15: Status: PASSED


网站公告

今日签到

点亮在社区的每一天
去签到