The pre-configured virtual environment for this guide is CUDA 11.8 + torch 2.2.2 + Python 3.10.
1. Create the conda virtual environment: conda create -n mamba python=3.10
Activate it: conda activate mamba
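Once the environment is active, a quick check confirms the interpreter matches the version requested above (a minimal sketch; run it inside the activated env):
import sys
print(sys.version)                      # expect 3.10.x
assert sys.version_info[:2] == (3, 10)  # matches the env created above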
2. Install the PyTorch environment:
conda install pytorch==2.2.2 torchvision==0.17.2 torchaudio==2.2.2 pytorch-cuda=11.8 -c pytorch -c nvidia
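Before installing the custom CUDA wheels below, it is worth confirming that this PyTorch build can actually see the GPU; both calls are standard torch APIs, and if the first prints False the wheels in the next steps will fail to load:
import torch
print(torch.cuda.is_available())      # expect True
print(torch.cuda.get_device_name(0))  # your GPU model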
3. Install the causal_conv1d library (other versions may be chosen):
pip install https://github.com/Dao-AILab/causal-conv1d/releases/download/v1.5.0.post8/causal_conv1d-1.5.0.post8+cu11torch2.2cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
4. Install the mamba_ssm library (other versions may be chosen):
pip install https://github.com/state-spaces/mamba/releases/download/v2.2.4/mamba_ssm-2.2.4+cu11torch2.2cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
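After both wheels install, their versions can be confirmed from Python. This assumes both packages expose __version__ at the top level, which recent releases do:
import causal_conv1d, mamba_ssm
print(causal_conv1d.__version__)  # e.g. 1.5.0.post8
print(mamba_ssm.__version__)      # e.g. 2.2.4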
Note: when choosing causal_conv1d and mamba_ssm wheels, if you are unsure which torch and CUDA versions you have, print them and match the output against the tags in the wheel filenames above (e.g. torch 2.2.x with CUDA 11.8 on Python 3.10 corresponds to cu11torch2.2 and cp310):
import torch
print(torch.__version__)
print(torch.version.cuda)
print(torch.backends.cudnn.version())
5. Test the imports:
from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
from mamba_ssm.ops.triton.selective_state_update import selective_state_update
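If both imports succeed, the CUDA kernel itself can be exercised with a tiny forward call. A minimal sketch, assuming the interface documented in the causal-conv1d README: x of shape (batch, dim, seqlen), weight of shape (dim, width) with width in {2, 3, 4}:
import torch
from causal_conv1d import causal_conv1d_fn
batch, dim, seqlen, width = 2, 16, 64, 4
x = torch.randn(batch, dim, seqlen, device="cuda")
weight = torch.randn(dim, width, device="cuda")
bias = torch.randn(dim, device="cuda")
out = causal_conv1d_fn(x, weight, bias, activation="silu")
print(out.shape)  # expect torch.Size([2, 16, 64])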
6. Test with a complete program
Mamba test:
import torch
from mamba_ssm import Mamba
batch, length, dim = 2, 64, 16
x = torch.randn(batch, length, dim).to("cuda")
model = Mamba(
    # This module uses roughly 3 * expand * d_model^2 parameters
    d_model=dim,  # Model dimension d_model
    d_state=16,   # SSM state expansion factor
    d_conv=4,     # Local convolution width
    expand=2,     # Block expansion factor
).to("cuda")
y = model(x)
assert y.shape == x.shape
print(y)
print(y.shape)
Output:
tensor([[[-0.0035, 0.0065, 0.0156, ..., 0.0144, 0.0181, -0.0319],
[ 0.0144, -0.0009, -0.0034, ..., 0.0666, 0.0301, 0.0096],
[-0.0102, -0.0006, 0.0264, ..., -0.0478, 0.0402, -0.0818],
...,
[-0.0177, -0.0073, 0.0390, ..., -0.0340, -0.0025, -0.0271],
[ 0.0138, 0.0031, -0.0223, ..., -0.0278, 0.0072, -0.0143],
[-0.0076, -0.0186, 0.0079, ..., -0.0062, 0.0021, -0.0283]],
[[ 0.0079, 0.0089, -0.0111, ..., -0.0010, 0.0097, 0.0033],
[ 0.0176, 0.0140, 0.0087, ..., -0.0223, 0.0294, -0.0158],
[-0.0149, 0.0228, 0.0126, ..., 0.0382, -0.0015, -0.0182],
...,
[ 0.0322, -0.0258, 0.0312, ..., -0.0045, -0.0036, -0.0509],
[ 0.0164, 0.0001, -0.0060, ..., 0.0722, 0.0058, 0.0061],
[ 0.0365, 0.0176, 0.0470, ..., 0.0103, 0.0257, 0.0039]]],
device='cuda:0', grad_fn=<MambaInnerFnBackward>)
torch.Size([2, 64, 16])
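The constructor comment estimates roughly 3 * expand * d_model^2 parameters; the actual count is easy to check directly (reusing model from the test above; note that at this small d_model the d_state-related tensors make the estimate loose):
print(sum(p.numel() for p in model.parameters()))  # total trainable parameters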
Mamba2 test:
from mamba_ssm import Mamba2
import torch
batch, length, dim = 2, 64, 512
x = torch.randn(batch, length, dim).to("cuda")
model = Mamba2(
    # This module uses roughly 3 * expand * d_model^2 parameters
    # make sure d_model * expand / headdim = multiple of 8
    d_model=dim,  # Model dimension d_model
    d_state=64,   # SSM state expansion factor, typically 64 or 128
    d_conv=4,     # Local convolution width
    expand=2,     # Block expansion factor
    headdim=64,   # default 64
).to("cuda")
y = model(x)
assert y.shape == x.shape
print(y)
print(y.shape)
Output:
tensor([[[-0.7895, -0.0058, -0.6374, ..., -0.1971, -0.1220, 0.9073],
[-0.2546, 0.3763, -0.6773, ..., -0.2671, 0.4855, -0.7709],
[ 0.4188, 0.3607, 0.7131, ..., 0.9439, -0.0798, -1.2252],
...,
[-1.3136, -1.1002, -0.5330, ..., 1.5189, 0.2091, 1.0726],
[-0.4520, -0.6626, -0.3810, ..., 0.3964, 0.0947, 0.7275],
[-0.3024, -0.2375, 0.2435, ..., 0.4073, 0.4688, 0.6197]],
[[ 0.0257, 0.7625, 0.7594, ..., 0.3531, 0.3276, 0.4292],
[-1.0428, 0.8166, 0.1294, ..., 0.8236, 0.0515, -0.3141],
[ 0.1267, -0.6214, 0.1667, ..., -0.4576, 0.7774, 0.7242],
...,
[ 0.0609, -0.0283, 0.4718, ..., 0.7035, 0.2011, 0.2541],
[ 0.3708, -0.0039, 0.2280, ..., 0.9191, -0.6267, -0.2572],
[-0.2993, 0.0933, 0.2601, ..., 0.9123, -1.0403, 0.7488]]],
device='cuda:0', grad_fn=<MambaSplitConv1dScanCombinedFnBackward>)
torch.Size([2, 64, 512])
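Both programs above only exercise the forward kernels. A one-line backward pass, reusing y from the Mamba2 test, additionally confirms that the fused backward kernels run (any scalar loss works for this check):
y.sum().backward()
print(next(model.parameters()).grad is not None)  # expect True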