// 内存分配抽象定义二 (Memory allocation abstraction, part 2)
//
// 发布于:2024-06-22 ⋅ 阅读:(102) ⋅ 点赞:(0)  (published 2024-06-22; 102 views, 0 likes)
#include <algorithm>
#include <cstddef>
#include <cstring>
#include <iostream>
#include <limits>
#include <memory>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <vector>

#include <cuda_runtime.h>

template <typename T>
struct MemoryDeleter {
	bool UseCUDA; // 成员变量用于标记是否使用CUDA

	MemoryDeleter(bool useCUDA) : UseCUDA(useCUDA) {} // 构造函数初始化UseCUDA

	void operator()(T* ptr) {
		if (UseCUDA) {
			cudaError_t cudaStatus = cudaFree(ptr);
			if (cudaStatus != cudaSuccess) {
				std::cerr << "CUDA memory free error: " << cudaGetErrorString(cudaStatus) << std::endl;
			}
		}
		else {
			delete[] ptr; // 使用delete[]释放CPU内存
		}
	}
};

template <typename T, bool UseCUDA>
using SharedMemoryPtr = std::conditional_t<UseCUDA, std::shared_ptr<T>, std::unique_ptr<T[], MemoryDeleter<T>>>;

/// Allocate / fill / copy helpers for either host memory (UseCUDA == false)
/// or CUDA device memory (UseCUDA == true), selected at compile time.
template <typename T, bool UseCUDA>
class MemoryManager {
public:
	/// Allocates `size` elements (not bytes) of T. The returned handle frees
	/// the buffer with the matching deallocator when it (or, for the shared_ptr
	/// variant, its last copy) is destroyed.
	/// [[nodiscard]]: discarding the handle would free the buffer immediately.
	[[nodiscard]] static SharedMemoryPtr<T, UseCUDA> Allocate(size_t size);

	/// Fills `size` elements starting at `ptr` using `value`.
	/// NOTE(review): the device path fills bytes (cudaMemset) while the host
	/// path assigns static_cast<T>(value) to each element.
	static void Set(T* ptr, int value, size_t size);

	/// Copies `size` elements from `src` to `dest`.
	/// The device path copies host `src` to device `dest` (cudaMemcpyHostToDevice).
	static void Copy(T* dest, const T* src, size_t size);
};

/// Allocates `size` elements of T on the CUDA device (UseCUDA) or on the host.
///
/// @param size  number of elements, not bytes.
/// @return owning handle whose deleter matches the allocator used here.
/// @throws std::length_error   device path: size * sizeof(T) overflows size_t.
/// @throws std::runtime_error  device path: cudaMalloc reported an error.
/// @throws std::bad_alloc      host path: new[] failed.
template <typename T, bool UseCUDA>
SharedMemoryPtr<T, UseCUDA> MemoryManager<T, UseCUDA>::Allocate(size_t size) {
	T* ptr = nullptr;
	if constexpr (UseCUDA) {
		// new[] validates the byte count itself; cudaMalloc does not, so guard
		// against a wrapped-around (too small) request.
		if (size > std::numeric_limits<size_t>::max() / sizeof(T)) {
			throw std::length_error("CUDA allocation size overflows size_t");
		}
		// For zero elements leave ptr null instead of relying on cudaMalloc's
		// handling of a 0-byte request; cudaFree(nullptr) in the deleter is a no-op.
		if (size != 0) {
			const cudaError_t cudaStatus = cudaMalloc(&ptr, size * sizeof(T));
			if (cudaStatus != cudaSuccess) {
				throw std::runtime_error(std::string("CUDA memory allocation error: ") + cudaGetErrorString(cudaStatus));
			}
		}
	}
	else {
		ptr = new T[size];
	}
	// If the shared_ptr control-block allocation throws, shared_ptr still
	// invokes the deleter on ptr, so the buffer cannot leak here.
	return SharedMemoryPtr<T, UseCUDA>(ptr, MemoryDeleter<T>(UseCUDA));
}

/// Fills `size` elements starting at `ptr`.
///
/// NOTE(review): the two paths disagree for non-zero `value`:
///   - device: cudaMemset writes `value` as a constant *byte* into every byte
///     (e.g. Set(p, 1, n) on float yields 0x01010101 bit patterns, not 1.0f);
///   - host:   every *element* becomes static_cast<T>(value).
/// For arithmetic T they coincide when value == 0. Confirm which semantics
/// callers expect before relying on non-zero values.
/// @throws std::runtime_error device path: cudaMemset reported an error.
template <typename T, bool UseCUDA>
void MemoryManager<T, UseCUDA>::Set(T* ptr, int value, size_t size) {
	if constexpr (UseCUDA) {
		const cudaError_t cudaStatus = cudaMemset(ptr, value, size * sizeof(T));
		if (cudaStatus != cudaSuccess) {
			throw std::runtime_error(std::string("CUDA memory set error: ") + cudaGetErrorString(cudaStatus));
		}
	}
	else {
		std::fill_n(ptr, size, static_cast<T>(value));
	}
}

/// Copies `size` elements from `src` to `dest`.
///
/// Device path: the direction is fixed to cudaMemcpyHostToDevice, so `src`
/// must be host memory and `dest` device memory.
/// Host path: the two ranges should not overlap.
/// @throws std::runtime_error device path: cudaMemcpy reported an error.
template <typename T, bool UseCUDA>
void MemoryManager<T, UseCUDA>::Copy(T* dest, const T* src, size_t size) {
	if constexpr (UseCUDA) {
		const cudaError_t cudaStatus = cudaMemcpy(dest, src, size * sizeof(T), cudaMemcpyHostToDevice);
		if (cudaStatus != cudaSuccess) {
			throw std::runtime_error(std::string("CUDA memory copy error: ") + cudaGetErrorString(cudaStatus));
		}
	}
	else {
		// Same result as memcpy for trivially copyable T, and (unlike memcpy)
		// well-defined for types with non-trivial copy assignment.
		std::copy_n(src, size, dest);
	}
}

/// Demo: allocate a device buffer, zero it, then upload host data into it.
int main() {
	try {
		// 512 * 512 * 500 floats = 500 MiB, on both the host and the device.
		const size_t size = 512 * 512 * 500;
		SharedMemoryPtr<float, true> ptr = MemoryManager<float, true>::Allocate(size);

		const int value = 0;
		MemoryManager<float, true>::Set(ptr.get(), value, size);

		// The vector owns the staging buffer, so it is released on every exit
		// path (the previous new float[] was never deleted).
		std::vector<float> hostData(size);
		for (size_t i = 0; i < size; ++i) {
			hostData[i] = static_cast<float>(i);
		}

		MemoryManager<float, true>::Copy(ptr.get(), hostData.data(), size);
		// ptr's deleter releases the device buffer when it goes out of scope.
	}
	catch (const std::exception& e) {
		std::cerr << "Error: " << e.what() << '\n';
		return 1;
	}
	return 0;
}


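/*
 * Scraped page chrome from the source website (not part of the program):
 * 网站公告 (site announcements)
 * 今日签到 (daily check-in)
 * 点亮在社区的每一天 (light up every day in the community)
 * 去签到 (go check in)
 */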