MATLAB C++语言编写MEX函数示例:求解本源根

发布于:2024-08-03 ⋅ 阅读:(102) ⋅ 点赞:(0)

MATLAB的MEX函数可以通过编译的机器代码,替代低效的脚本语言提升运行效率(以及隐藏原始代码保护知识产权)。MEX函数最初支持C语言编写,从2018a开始支持基于C++11的“现代”C++编写MEX,并实现更多“现代”特性(主要是程序内存安全性)。目前,MATLAB官方已不推荐继续用传统C语言编写新的MEX函数:C Matrix API - MATLAB & Simulink - MathWorks 中国
在这个博客,我展示一个搜索质数本源根的程序,详细说明如何把C++代码嵌入MEX函数中。

首先说明什么是本源根:本源根是离散数学上的一个概念,考虑对一个质数 p p p,若正整数 a a a的各次幂除以 p p p的余数正好可以产生 1 1 1 p − 1 p-1 p1的所有整数,即 a  mod  p a \text{ mod } p a mod p a 2  mod  p a^2 \text{ mod } p a2 mod p、…、 a p − 1  mod  p a^{p-1} \text{ mod } p ap1 mod p各不相同,那么 a a a就是 p p p的一个本源根。本源根在非对称密码学上非常有用,而且显而易见,对于大质数,求出全部本源根需要穷举,这种穷举就非常适合通过C/C++编程来优化。

首先给出代码:

#include "mex.hpp"
#include "mexAdapter.hpp"

using namespace matlab::data;
using matlab::mex::ArgumentList;

class MexFunction : public matlab::mex::Function {
public:
    void operator()(ArgumentList outputs, ArgumentList inputs) {
        std::shared_ptr<matlab::engine::MATLABEngine> matlabPtr = getEngine();
        ArrayFactory factory;
        // Validate arguments
        checkArguments(outputs, inputs);

        TypedArray<double> inputArray = std::move(inputs[0]);
        int num =(int)inputArray[0];
        if (num <= 2){
            matlabPtr->feval(u"error", 0, 
                std::vector<Array>({ factory.createScalar("Input must be an integer larger than 2") }));
        }

        // 算法实现
        int i, j, cur_num;
        int search_list[num];
        bool isPriRoot;
        std::vector<int> vec_priroots;
        for (i=2; i<num-1; i++){
            isPriRoot = true;
            for (j=0; j<num; j++) search_list[j] = 0;
            cur_num = i;
            for (j=0; j<num - 1; j++){
                search_list[cur_num]++;
                if (search_list[cur_num] >= 2) {
                    isPriRoot = false;
                    break;
                }
                cur_num = (cur_num * i) % num;
            }
            if (isPriRoot) vec_priroots.push_back(i);
        }

        // Assign outputs
        TypedArray<double> pri_root_Array = factory.createArray<double>({vec_priroots.size(),1});
        for (i=0; i<vec_priroots.size(); i++){
            pri_root_Array[i] = vec_priroots[i];
        }
        outputs[0] = pri_root_Array;
    }

    void checkArguments(ArgumentList outputs, ArgumentList inputs) {
        std::shared_ptr<matlab::engine::MATLABEngine> matlabPtr = getEngine();
        ArrayFactory factory;
        
        if (inputs[0].getType() != ArrayType::DOUBLE ||
            inputs[0].getType() == ArrayType::COMPLEX_DOUBLE ||
            inputs[0].getNumberOfElements() != 1)
        {
            matlabPtr->feval(u"error", 0, 
                std::vector<Array>({ factory.createScalar("Input must be an integer larger than 2") }));
        }

        if (outputs.size() > 1) {
            matlabPtr->feval(u"error", 0, 
                std::vector<Array>({ factory.createScalar("Only one output is returned") }));
        }
    }
};

从头看起,MEX C++严格来说不是在编写函数, 而是在编写一个名为MexFunction的类,继承自matlab::mex::Function,然后重载这个类的括号运算符operator()

void operator()(ArgumentList outputs, ArgumentList inputs)

与C语言的MEX接口不同,这里通过两个matlab::mex::ArgumentList容器,分别传入函数的输入inputs和输出outputs。它们传入的数据类型为TypedArray<T>,可以是MATLAB的二维或多维矩阵。
在检验输入数据符合算法要求后(checkArguments(outputs, inputs),这里略过该函数的实现),我们用std::move接收第一个输入Array的引用(这是官方推荐做法,避免因类型转换多生成副本),然后将该Array第一个数据取出,就是我们算法的输入,即待求本源根的质数 p p p

TypedArray<double> inputArray = std::move(inputs[0]);
int num =(int)inputArray[0];

这里需要说明,如果输入参数不止一个,可以用如下循环逐个访问类型为输入数据中的参数。为了避免复制过大的数据矩阵,迭代器采用了引用类型:

TypedArray<double> doubleArray = std::move(inputs[0]);
for (auto& elem : doubleArray) {
            // do something with elem, the type of elem is double
}

再之后的算法实现就是搜索每一个小于 p p p的整数的各次幂余数是否重复,这个过程很简单,我也就用C语言实现(不推荐!),不多赘述。

现在说明一些不太需要深究的内容。本代码声明了两个变量:

std::shared_ptr<matlab::engine::MATLABEngine> matlabPtr = getEngine();
ArrayFactory factory;

其中matlabPtr 主要的作用是通过feval方法调用MATLAB的函数,在本段代码中只用于输出错误信息;matlab::data::ArrayFactory类可以通过模板产生类型为TypedArray<T>的数据,故我们在最后调用该类组装输出数据:

TypedArray<double> pri_root_Array = factory.createArray<double>({vec_priroots.size(),1});
for (i=0; i<vec_priroots.size(); i++){
     pri_root_Array[i] = vec_priroots[i];
}

最后,通过

outputs[0] = pri_root_Array;

将组装的数据交给输出容器的第一个数,大功告成。

可以看到,MATLAB通过容器向MEX函数(对象)传入传出数据。我们在编写MEX C++代码的时候,只要能从ArgumentList inputs取出输入数据,再把输出结果放进ArgumentList outputs,其他的实现和我们平时编写C++代码没有任何区别。特别值得说明的是,各种C++11的标准库容器(比如std::vector)都可以在MEX函数中使用,这也方便我们移植已有算法。


网站公告

今日签到

点亮在社区的每一天
去签到