动态规划多阶段报童模型,c++ 实现, java 实现

发布于:2025-03-03 ⋅ 阅读:(17) ⋅ 点赞:(0)

借助 ChatGPT 和 DeepSeek,成功实现了 C++ 上的多阶段报童模型的动态规划。花费了几天,将以前的 Java 程序用 C++ 重新实现。

文章目录

总结:

  • c++ 还是比 java 快点,30个阶段快了零点几秒
  • c++ 使用了 unordered_map ,存储递归数据
  • java 使用了 ConcurrentSkipListMap 存储递归数据,这个可以按照排序器自动排序
  • 若 c++ 也用可以排序的 map,速度反而比 java 慢了。理论上c++会快,但估计需要其他的一些功能设置
  • c++ 编译时要开启 -O2 或 -O3 优化加速(注意是大写 O;小写 -o 是用来指定输出文件名的)

在这里插入图片描述

在这里插入图片描述

C++ 代码

//
// Created by Zhen Chen on 2025/2/26.
//

#include <algorithm>
#include <cctype>
#include <chrono>
#include <cmath>
#include <csignal>
#include <cstddef>
#include <iostream>
#include <limits>
#include <map>
#include <span>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include <boost/functional/hash.hpp>

// A dynamic-programming state: the period index together with the inventory
// level at the start of that period.
class State {
    int period{}; // brace (value) initialization, so the default is 0
    double initialInventory{};

public:
    State();

    explicit State(int period, double initialInventory);

    [[nodiscard]] double getInitialInventory() const;

    [[nodiscard]] int getPeriod() const;

    void print() const;

    // Equality is needed (together with a hash functor) to use State as a key
    // of an unordered_map. Neither operand is modified: the argument is a
    // const reference and the member function itself is const.
    bool operator==(const State &other) const {
        if (period != other.period) {
            return false;
        }
        return initialInventory == other.initialInventory;
    }

    // The hash functor reads the private members directly.
    friend struct std::hash<State>;

    // Ordering for sorted containers: compare by period first and fall back
    // to the initial inventory when the periods are equal.
    bool operator<(const State &other) const {
        if (period != other.period) {
            return period < other.period;
        }
        return initialInventory < other.initialInventory;
    }
};


// Explicit specialization of std::hash for State, so that State can be used
// as the key type of std::unordered_map.
template<>
struct std::hash<State> {
    // noexcept: computing the hash never throws.
    std::size_t operator()(const State &s) const noexcept {
        // Fold both fields into one value with boost::hash_combine.
        std::size_t combined = 0;
        boost::hash_combine(combined, s.period);
        boost::hash_combine(combined, s.initialInventory);
        return combined;

        // Boost-free alternative that was considered:
        //   std::hash<int>()(s.period) ^ std::hash<double>()(s.initialInventory) << 1
        // i.e. XOR the hash of the period with the hash of the inventory shifted
        // left by one bit, so that the two hashes are not simply overlaid.
    }
};

// Default constructor: both members keep their in-class initializers.
State::State() = default;

// Builds the state for `period` with starting inventory `initialInventory`.
State::State(const int period, const double initialInventory)
    : period(period),
      initialInventory(initialInventory) {
}

double State::getInitialInventory() const {
    return initialInventory;
}

int State::getPeriod() const {
    return period;
}

// Writes "period: <period>, ini I: <inventory>" to std::cout, ends the line
// and flushes the stream.
void State::print() const {
    std::cout << "period: " << period;
    std::cout << ", ini I: " << initialInventory;
    std::cout << std::endl;
}


// Builds truncated, discretized demand tables (value/probability pairs) for
// every period of the planning horizon.
class ProbabilityMassFunctions {
    // pmf[t][j] is a two-element vector: {demand value, probability}.
    using PmfTable = std::vector<std::vector<std::vector<double> > >;

    double truncatedQuantile; // quantile passed to poissonQuantile to bound the support
    double stepSize; // spacing between consecutive demand values of the table
    std::string distributionName; // validated by checkName() in the constructor

public:
    ProbabilityMassFunctions(double truncatedQuantile, double stepSize, std::string distributionName);

    // std::string getName();

    // Validates distributionName.
    void checkName() const;

    // Dispatches on the distribution name; `demands` holds one mean per period.
    [[nodiscard]] PmfTable getPMF(std::span<double> demands) const;

    // Poisson-specific table construction.
    [[nodiscard]] PmfTable getPMFPoisson(std::span<double> demands) const;

    // Poisson helpers: probability mass, cumulative distribution and its inverse.
    static double poissonPMF(int k, double lambda);

    static double poissonCDF(int k, double lambda);

    static int poissonQuantile(double p, double lambda);
};

// Stores the settings and validates the distribution name right away.
// The name is taken by value and moved into the member to avoid a second copy.
ProbabilityMassFunctions::ProbabilityMassFunctions(const double truncatedQuantile,
                                                   const double stepSize,
                                                   std::string distributionName)
    : truncatedQuantile(truncatedQuantile),
      stepSize(stepSize),
      distributionName(std::move(distributionName)) {
    checkName();
}

// Validates the configured distribution name (case-insensitively).
//
// Only "poisson" is supported. Any other name is reported on std::cout and
// rejected with std::invalid_argument. The previous `raise(-1)` passed an
// invalid signal number, so the call failed silently and construction went on
// with an unsupported distribution.
//
// Throws: std::invalid_argument when the name is not "poisson".
void ProbabilityMassFunctions::checkName() const {
    std::string lowered = distributionName;
    // std::tolower must be given a value representable as unsigned char.
    std::transform(lowered.begin(), lowered.end(), lowered.begin(),
                   [](const unsigned char c) { return static_cast<char>(std::tolower(c)); });
    if (lowered != "poisson") {
        std::cout << " distribution not found or to do next for this distribution\n";
        throw std::invalid_argument("unsupported distribution: " + distributionName);
    }
}

// Probability mass of a Poisson(lambda) variable at k.
//
// Returns 0 for k < 0 or lambda <= 0. The value is evaluated in log space,
//   exp(k * ln(lambda) - lambda - ln(k!)),  with ln(k!) = lgamma(k + 1),
// because the direct form pow(lambda, k) / tgamma(k + 1) overflows to inf / inf
// (NaN) once k is large enough (tgamma overflows a double around k = 171).
double ProbabilityMassFunctions::poissonPMF(const int k, const double lambda) {
    if (k < 0 || lambda <= 0) return 0.0; // reject invalid arguments
    const double logMass = k * std::log(lambda) - lambda - std::lgamma(k + 1.0);
    return std::exp(logMass);
}


// Cumulative distribution function of Poisson(lambda): P(X <= k).
// The terms P(X = i) are produced by the recurrence
//   P(X = i + 1) = P(X = i) * lambda / (i + 1),  starting from P(X = 0) = exp(-lambda).
// For k < 0 no term is added and the result is 0.
double ProbabilityMassFunctions::poissonCDF(const int k, const double lambda) {
    double massAtI = std::exp(-lambda);
    double total = 0.0;
    int i = 0;
    while (i <= k) {
        total += massAtI;
        if (i != k) {
            // advance to P(X = i + 1); skipped after the last term
            massAtI *= lambda / (i + 1);
        }
        ++i;
    }
    return total;
}

// Inverse CDF of Poisson(lambda): binary search for the smallest k in
// [0, max(100, 3 * lambda)] whose CDF is not below p. If no such k lies in
// that interval, the upper end of the interval is returned.
int ProbabilityMassFunctions::poissonQuantile(const double p, const double lambda) {
    int lower = 0;
    int upper = std::max(100, static_cast<int>(lambda * 3)); // initial search interval
    while (lower < upper) {
        const int middle = (lower + upper) / 2;
        const bool belowTarget = poissonCDF(middle, lambda) < p;
        if (belowTarget) {
            lower = middle + 1;
        } else {
            upper = middle;
        }
    }
    return lower;
}

// Returns the per-period demand table for the configured distribution.
//
// `demands` holds one mean demand per period. The distribution name is matched
// case-insensitively, the same way checkName() validates it. Previously a name
// such as "Poisson" passed validation but failed the exact comparison here,
// which silently produced an empty table.
//
// Returns an empty table when the distribution is not supported.
std::vector<std::vector<std::vector<double> > > ProbabilityMassFunctions::
getPMF(const std::span<double> demands) const {
    std::string lowered = distributionName;
    // std::tolower must be given a value representable as unsigned char.
    std::transform(lowered.begin(), lowered.end(), lowered.begin(),
                   [](const unsigned char c) { return static_cast<char>(std::tolower(c)); });
    if (lowered == "poisson") {
        return getPMFPoisson(demands);
    }
    return {};
}

// Builds the truncated Poisson demand table for every period.
//
// For period t the support runs from the (1 - truncatedQuantile) quantile to
// the truncatedQuantile quantile of Poisson(demands[t]) in steps of stepSize.
// Entry pmf[t][j] is {demand value, probability}; the probability is the
// Poisson mass at the demand (truncated to an integer) divided by
// 2 * truncatedQuantile - 1, the nominal mass kept between the two quantiles.
//
// Changes from the earlier version: the variable-length arrays (a compiler
// extension, not standard C++) are gone, indices are unsigned to match
// span::size(), and a negative support length is clamped to zero instead of
// being converted into a huge vector size.
std::vector<std::vector<std::vector<double> > > ProbabilityMassFunctions::
getPMFPoisson(const std::span<double> demands) const {
    const std::size_t T = demands.size();
    std::vector<std::vector<std::vector<double> > > pmf(T);
    for (std::size_t t = 0; t < T; ++t) {
        const int supportUB = poissonQuantile(truncatedQuantile, demands[t]);
        const int supportLB = poissonQuantile(1 - truncatedQuantile, demands[t]);
        const int demandLength = std::max(0, static_cast<int>((supportUB - supportLB + 1) / stepSize));

        pmf[t].reserve(static_cast<std::size_t>(demandLength));
        for (int j = 0; j < demandLength; ++j) {
            const double demandValue = supportLB + j * stepSize;
            const int demand = static_cast<int>(demandValue);
            const double probability = poissonPMF(demand, demands[t]) / (2 * truncatedQuantile - 1);
            pmf[t].push_back({demandValue, probability});
        }
    }

    return pmf;
}

// Multi-period inventory model solved by recursion with memoization.
// States are (period, initial inventory) pairs; the best order quantity and
// the optimal expected cost of every visited state are cached.
class NewsvendorDP {
    int T; // number of periods
    int capacity; // bound used when enumerating order quantities
    double stepSize; // spacing between consecutive order quantities
    double fixOrderCost; // charged whenever the order quantity is positive
    double unitVariOrderCost; // cost per unit ordered
    double unitHoldCost; // cost per unit of positive end-of-period inventory
    double unitPenaltyCost; // cost per unit of negative end-of-period inventory
    double truncatedQuantile; // stored by the constructor
    double max_I; // inventory is capped at this level
    double min_I; // inventory is floored at this level

    // pmf[t][j] = {demand value, probability} for period t + 1
    std::vector<std::vector<std::vector<double> > > pmf;

    std::unordered_map<State, double> cacheActions{}; // state -> best order quantity
    std::unordered_map<State, double> cacheValues{}; // state -> optimal expected cost

    // Ordered alternative:
    //    std::map<State, double> cacheActions{};
    //    std::map<State, double> cacheValues{};

public:
    NewsvendorDP(size_t T, int capacity, double stepSize, double fixOrderCost, double unitVariOrderCost,
                 double unitHoldCost, double unitPenaltyCost, double truncatedQuantile, double max_I, double min_I,
                 std::vector<std::vector<std::vector<double> > > pmf);

    // Order quantities that may be chosen in a period.
    [[nodiscard]] std::vector<double> feasibleActions() const;

    // Next state reached from `state` after ordering `action` and observing `demand`.
    [[nodiscard]] State stateTransitionFunction(const State &state, double action, double demand) const;

    // Cost incurred in the current period for the given action and demand.
    [[nodiscard]] double immediateValueFunction(const State &state, double action, double demand) const;

    // Cached best order quantity for `state`.
    [[nodiscard]] double getOptAction(const State &state);

    // One row {period, initial inventory, best action} per cached state.
    [[nodiscard]] auto getTable() const;

    // Optimal expected cost from `state` until the end of the horizon.
    double recursion(const State &state);
};

// Stores all model parameters. The horizon length is narrowed from size_t to
// int, and the demand table is moved into the member instead of being copied.
NewsvendorDP::NewsvendorDP(const size_t T, const int capacity,
                           const double stepSize, const double fixOrderCost,
                           const double unitVariOrderCost,
                           const double unitHoldCost, const double unitPenaltyCost,
                           const double truncatedQuantile, const double max_I,
                           const double min_I,
                           std::vector<std::vector<std::vector<double> > > pmf)
    : T(static_cast<int>(T)),
      capacity(capacity),
      stepSize(stepSize),
      fixOrderCost(fixOrderCost),
      unitVariOrderCost(unitVariOrderCost),
      unitHoldCost(unitHoldCost),
      unitPenaltyCost(unitPenaltyCost),
      truncatedQuantile(truncatedQuantile),
      max_I(max_I),
      min_I(min_I),
      pmf(std::move(pmf)) {
}


// Enumerates the order quantities 0, stepSize, 2 * stepSize, ... up to and
// including the largest multiple of stepSize that does not exceed capacity.
//
// The earlier version stopped one step short (largest action was
// capacity - stepSize), so ordering exactly `capacity` units was never
// considered even though capacity is the maximum order quantity.
//
// A non-positive stepSize or a negative capacity cannot be enumerated; in that
// case only "order nothing" is returned.
std::vector<double> NewsvendorDP::feasibleActions() const {
    if (stepSize <= 0 || capacity < 0) {
        return {0.0};
    }
    const int QNum = static_cast<int>(capacity / stepSize) + 1; // +1 includes the capacity itself
    std::vector<double> actions(QNum);
    for (int i = 0; i < QNum; ++i) {
        actions[i] = i * stepSize;
    }
    return actions;
}

// Computes the state that follows `state` when `action` units are ordered and
// `demand` units are requested: the inventory becomes
// initial inventory + action - demand, limited to [min_I, max_I], and the
// period advances by one.
//
// The two no-op `(void) nextInventory;` branches of the earlier version (left
// over from debugging) have been removed; the result is unchanged.
State NewsvendorDP::stateTransitionFunction(const State &state, const double action, const double demand) const {
    double nextInventory = state.getInitialInventory() + action - demand;
    if (nextInventory > max_I) {
        nextInventory = max_I;
    }
    if (nextInventory < min_I) {
        nextInventory = min_I;
    }
    return State{state.getPeriod() + 1, nextInventory};
}

double NewsvendorDP::immediateValueFunction(const State &state, const double action, const double demand) const {
    const double fixCost = action > 0 ? fixOrderCost : 0;
    const double variCost = action * unitVariOrderCost;
    double nextInventory = state.getInitialInventory() + action - demand;
    nextInventory = nextInventory > max_I ? max_I : nextInventory;
    nextInventory = nextInventory < min_I ? min_I : nextInventory;
    const double holdCost = std::max(unitHoldCost * nextInventory, 0.0);
    const double penaltyCost = std::max(-unitPenaltyCost * nextInventory, 0.0);

    const double totalCost = fixCost + variCost + holdCost + penaltyCost;
    return totalCost;
}

// Returns the cached best order quantity for `state`, or 0.0 when the state
// has not been solved.
//
// The lookup uses find() instead of operator[]: operator[] inserted a 0.0
// entry for every unknown state, which silently polluted the cache (and with
// it the table returned by getTable()).
double NewsvendorDP::getOptAction(const State &state) {
    const auto found = cacheActions.find(state);
    if (found == cacheActions.end()) {
        return 0.0;
    }
    return found->second;
}

auto NewsvendorDP::getTable() const {
    size_t stateNums = cacheActions.size();
    std::vector<std::vector<double> > table(stateNums, std::vector<double>(3));
    int index = 0;
    for (const auto &[fst, snd]: cacheActions) {
        table[index][0] = fst.getPeriod();
        table[index][1] = fst.getInitialInventory();
        table[index][2] = snd;
        index++;
    }
    return table;
}

double NewsvendorDP::recursion(const State &state) {
    double bestQ = 0.0;
    double bestValue = std::numeric_limits<double>::max();
    const std::vector<double> actions = feasibleActions();
    for (const double action: feasibleActions()) {
        double thisValue = 0;
        for (auto demandAndProb: pmf[state.getPeriod() - 1]) {
            thisValue += demandAndProb[1] * immediateValueFunction(state, action, demandAndProb[0]);
            if (state.getPeriod() < T) {
                auto newState = stateTransitionFunction(state, action, demandAndProb[0]);
                (void) action;
                if (cacheValues.contains(newState)) {
                    // some issues here
                    thisValue += demandAndProb[1] * cacheValues[newState];
                } else {
                    thisValue += demandAndProb[1] * recursion(newState);
                }
            }
        }
        if (thisValue < bestValue) {
            bestValue = thisValue;
            bestQ = action;
        }
    }
    cacheActions[state] = bestQ;
    cacheValues[state] = bestValue;
    return bestValue;
}


int main() {
    std::vector<double> demands(30, 20);
    const std::string distribution_type = "poisson";
    constexpr int capacity = 100; // maximum ordering quantity
    constexpr double stepSize = 1.0;
    constexpr double fixOrderCost = 0;
    constexpr double unitVariOderCost = 1;
    constexpr double unitHoldCost = 2;
    constexpr double unitPenaltyCost = 10;
    constexpr double truncQuantile = 0.9999; // truncated quantile for the demand distribution
    constexpr double maxI = 500; // maximum possible inventory
    constexpr double minI = -300; // minimum possible inventory


    const auto pmf = ProbabilityMassFunctions(truncQuantile, stepSize, distribution_type).getPMF(demands);
    const size_t T = demands.size();
    auto model = NewsvendorDP(T, capacity, stepSize, fixOrderCost, unitVariOderCost, unitHoldCost, unitPenaltyCost,
                              truncQuantile, maxI, minI, pmf);

    const auto initialState = State(1, 0);
    const auto start_time = std::chrono::high_resolution_clock::now();
    const auto optValue = model.recursion(initialState);
    const auto end_time = std::chrono::high_resolution_clock::now();
    const std::chrono::duration<double> duration = end_time - start_time;
    std::cout << "planning horizon is " << T << " periods" << std::endl;
    std::cout << "running time of C++ is " << duration << std::endl;
    std::cout << "Final optimal value is: " << optValue << std::endl;
     const auto optQ = model.getOptAction(initialState);
     std::cout << "Optimal Q is: " << optQ << std::endl;
    // auto table = model.getTable();
    return 0;
}

Java 代码


import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentSkipListMap;
import java.util.function.Function;
import java.util.stream.DoubleStream;
import java.util.stream.IntStream;

public class CLSP {
    double[][][] pmf;

    /**
     * Creates a model over the given per-period demand tables.
     *
     * @param pmf demand table indexed by period; as read by {@code f}, {@code pmf[t][j][0]}
     *            is a demand value and {@code pmf[t][j][1]} is its probability
     */
    public CLSP(double[][][] pmf) {
        this.pmf = pmf;
    }

    /**
     * A dynamic-programming state: the period index and the inventory level at the
     * start of that period.
     */
    class State {
        int period;
        double initialInventory;

        public State(int period, double initialInventory) {
            this.period = period;
            this.initialInventory = initialInventory;
        }

        /** Returns the order quantities produced by the enclosing model's action generator. */
        public double[] getFeasibleActions() {
            return actionGenerator.apply(this);
        }

        /**
         * Combines both fields numerically. The earlier version hashed the string
         * concatenation of the fields, which allocated on every call and was ambiguous:
         * (1, 23.0) and (12, 3.0) both produced "123.0".
         */
        @Override
        public int hashCode() {
            return 31 * Integer.hashCode(period) + Double.hashCode(initialInventory);
        }

        /**
         * Two states are equal when the periods match and the inventories compare equal
         * under {@link Double#compare(double, double)}, which keeps equality consistent
         * with {@link #hashCode()}.
         */
        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (!(o instanceof State)) {
                return false;
            }
            State other = (State) o;
            return period == other.period
                    && Double.compare(initialInventory, other.initialInventory) == 0;
        }

        @Override
        public String toString() {
            return "period = " + period + ", " + "initialInventory = " + initialInventory;
        }
    }

    /** Produces the candidate order quantities for a state; read by {@code State.getFeasibleActions()}. */
    Function<State, double[]> actionGenerator;


    /** Maps (state, action, random outcome) to the successor state. */
    interface StateTransitionFunction<S, A, R, S2> {
        public S2 apply(S s, A a, R r);
    }

    /** Transition hook applied by {@code f} to obtain the next state. */
    StateTransitionFunction<State, Double, Double, State> stateTransition;


    /** Maps (state, action, random outcome) to the value incurred in the current period. */
    interface ImmediateValueFunction<S, A, R, V> {
        public V apply(S s, A a, R r);
    }

    /** Immediate-value hook applied by {@code f} for every action/demand pair. */
    ImmediateValueFunction<State, Double, Double, Double> immediateValue;

    // Orders states by period first, then by initial inventory (Double.compare).
    Comparator<State> keyComparator = (o1, o2) -> o1.period > o2.period ? 1 :
            o1.period == o2.period ? Double.compare(o1.initialInventory, o2.initialInventory) : -1;

    // Memoization tables kept sorted by keyComparator:
    // state -> best order quantity, and state -> optimal value.
    ConcurrentSkipListMap<State, Double> cacheActions = new ConcurrentSkipListMap<>(keyComparator);
    ConcurrentSkipListMap<State, Double> cacheValues = new ConcurrentSkipListMap<>(keyComparator);

    /**
     * Returns the optimal expected value from {@code state} to the end of the horizon.
     *
     * <p>For every feasible order quantity the expected immediate value plus (when a
     * later period exists) the expected optimal value of the successor state is summed
     * over the demand table of the current period; the smallest total wins. The best
     * quantity and the optimal value are memoized in {@code cacheActions} and
     * {@code cacheValues}.
     *
     * <p>The cache is handled with an explicit get / putIfAbsent instead of recursing
     * inside {@code computeIfAbsent}: a mapping function that modifies the map it is
     * called from is only tolerated by some map implementations. The unused per-state
     * {@code QValues} array and the commented-out stream-based variant were dropped.
     *
     * @param state the state to evaluate
     * @return the minimal expected value from {@code state} onwards
     */
    double f(State state) {
        Double cached = cacheValues.get(state);
        if (cached != null) {
            return cached;
        }

        double[] feasibleActions = state.getFeasibleActions();
        double[][] dAndP = pmf[state.period - 1]; // rows of {demand, probability}
        boolean hasNextPeriod = state.period < pmf.length;

        double val = Double.MAX_VALUE;
        double bestOrderQty = 0;
        for (double orderQty : feasibleActions) {
            double thisQValue = 0;
            for (double[] demandAndProb : dAndP) {
                double demand = demandAndProb[0];
                double probability = demandAndProb[1];
                thisQValue += probability * immediateValue.apply(state, orderQty, demand);
                if (hasNextPeriod) {
                    State newState = stateTransition.apply(state, orderQty, demand);
                    thisQValue += probability * f(newState);
                }
            }
            if (thisQValue < val) {
                val = thisQValue;
                bestOrderQty = orderQty;
            }
        }

        cacheActions.putIfAbsent(state, bestOrderQty);
        cacheValues.putIfAbsent(state, val);
        return val;
    }



    /**
     * Solves a 30-period instance with Poisson demand (mean 20 in every period) and
     * prints the horizon length, the running time, the optimal expected value and the
     * best first-period order quantity.
     *
     * <p>Changes from the earlier version: the discrete/continuous test now inspects the
     * distribution of the period being filled ({@code distributions[i]}, previously
     * always {@code distributions[0]}); the number of order quantities accounts for
     * {@code stepSize} (previously {@code maxOrderQuantity + 1} regardless of the step);
     * index streams use {@code IntStream.range}.
     */
    public static void main(String[] args) {
        double initialInventory = 0;
        double[] meanDemand = new double[30];
        Arrays.fill(meanDemand, 20);

        double truncationQuantile = 0.9999;
        double stepSize = 1;
        double minState = -150;
        double maxState = 300;
        int T = meanDemand.length;

        double fixedOrderingCost = 0;
        double proportionalOrderingCost = 1;
        double holdingCost = 2;
        double penaltyCost = 10;

        int maxOrderQuantity = 100;

        // NOTE(review): Distribution, PoissonDist and DiscreteDistribution are not imported
        // in this listing - presumably they come from an external probability-distribution
        // library; confirm and add the import.
        // Alternatives that were tried here: new UniformDist(0, meanDemand[i]) and
        // new NormalDist(meanDemand[i], 0.25 * meanDemand[i]).
        Distribution[] distributions = IntStream.range(0, T)
                .mapToObj(i -> new PoissonDist(meanDemand[i]))
                .toArray(Distribution[]::new);
        double[] supportLB = IntStream.range(0, T)
                .mapToDouble(i -> distributions[i].inverseF(1 - truncationQuantile))
                .toArray();
        double[] supportUB = IntStream.range(0, T)
                .mapToDouble(i -> distributions[i].inverseF(truncationQuantile))
                .toArray();

        double[][][] pmf = new double[T][][];
        for (int i = 0; i < T; i++) {
            int demandLength = (int) ((supportUB[i] - supportLB[i] + 1) / stepSize);
            pmf[i] = new double[demandLength][];
            // demand values are all integers
            for (int j = 0; j < demandLength; j++) {
                pmf[i][j] = new double[2];
                pmf[i][j][0] = supportLB[i] + j * stepSize;
                int demand = (int) pmf[i][j][0];
                if (distributions[i] instanceof DiscreteDistribution) {
                    // Alternative normalization:
                    // distributions[i].cdf(supportUB[i]) - distributions[i].cdf(supportLB[i])
                    double probabilitySum = 2 * truncationQuantile - 1;
                    pmf[i][j][1] = ((DiscreteDistribution) distributions[i]).prob(demand) / probabilitySum;
                } else {
                    double probabilitySum = distributions[i].cdf(supportUB[i] + 0.5 * stepSize)
                            - distributions[i].cdf(supportLB[i] - 0.5 * stepSize);
                    pmf[i][j][1] = (distributions[i].cdf(pmf[i][j][0] + 0.5 * stepSize)
                            - distributions[i].cdf(pmf[i][j][0] - 0.5 * stepSize)) / probabilitySum;
                }
            }
        }

        CLSP inventory = new CLSP(pmf);

        // Order quantities 0, stepSize, 2 * stepSize, ... up to maxOrderQuantity.
        long actionCount = (long) (maxOrderQuantity / stepSize) + 1;
        inventory.actionGenerator = s ->
                DoubleStream.iterate(0, q -> q + stepSize).limit(actionCount).toArray();

        inventory.stateTransition = (state, action, randomDemand) -> {
            double nextInventory = state.initialInventory + action - randomDemand;
            nextInventory = nextInventory > maxState ? maxState : nextInventory;
            nextInventory = nextInventory < minState ? minState : nextInventory;
            return inventory.new State(state.period + 1, nextInventory);
        };

        inventory.immediateValue = (state, action, randomDemand) -> {
            double fixedCost = action > 0 ? fixedOrderingCost : 0;
            double variableCost = proportionalOrderingCost * action;
            double inventoryLevel = state.initialInventory + action - randomDemand;
            double holdingCosts = holdingCost * Math.max(inventoryLevel, 0);
            double penaltyCosts = penaltyCost * Math.max(-inventoryLevel, 0);
            double totalCosts = fixedCost + variableCost + holdingCosts + penaltyCosts;
            return totalCosts;
        };

        int period = 1;
        State initialState = inventory.new State(period, initialInventory);
        long currTime2 = System.currentTimeMillis();

        double finalValue = inventory.f(initialState);
        double time = (System.currentTimeMillis() - currTime2) / 1000.000;
        System.out.println("planning horizon is " + meanDemand.length + " periods");
        System.out.println("running time of Java is " + time + " s");

        System.out.println("final optimal expected value is: " + finalValue);

        double optQ = inventory.cacheActions.get(inventory.new State(period, initialInventory));
        System.out.println("optimal order quantity in the first priod is : " + optQ);

    }
}