Quick installation of the mamba library on a CentOS 7 server

Published: 2025-04-07

The pre-configured virtual environment used here is CUDA 11.8 + PyTorch 2.2.2 + Python 3.10.

1. Create a conda virtual environment: conda create -n mamba python=3.10

    Activate the environment: conda activate mamba

2. Install PyTorch:

conda install pytorch==2.2.2 torchvision==0.17.2 torchaudio==2.2.2 pytorch-cuda=11.8 -c pytorch -c nvidia
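
After the install finishes, it is worth a quick check that PyTorch was built with CUDA and can actually see the GPU. A minimal sanity check, assuming the NVIDIA driver on the server is already installed:

import torch

# Confirm the CUDA build of PyTorch is installed and a GPU is visible
print(torch.__version__)              # should report 2.2.2
print(torch.cuda.is_available())      # should print True on a working CUDA setup
if torch.cuda.is_available():
    print(torch.cuda.get_device_name(0))  # name of the first GPU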

3. Install the causal-conv1d library (choose the version that matches your environment)

pip install https://github.com/Dao-AILab/causal-conv1d/releases/download/v1.5.0.post8/causal_conv1d-1.5.0.post8+cu11torch2.2cxx11abiFALSE-cp310-cp310-linux_x86_64.whl

4. Install the mamba_ssm library (choose the version that matches your environment)

pip install https://github.com/state-spaces/mamba/releases/download/v2.2.4/mamba_ssm-2.2.4+cu11torch2.2cxx11abiFALSE-cp310-cp310-linux_x86_64.whl

Note: when picking the causal_conv1d and mamba_ssm wheels, if you are not sure which torch and CUDA versions you have, you can print them:

import torch
print(torch.__version__)               # PyTorch version
print(torch.version.cuda)              # CUDA version PyTorch was built with
print(torch.backends.cudnn.version())  # cuDNN version
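
The wheel filenames also encode the C++11 ABI flag (cxx11abiFALSE / cxx11abiTRUE) and the Python tag (cp310). If you want to double-check those two as well, a small sketch using PyTorch's compiled_with_cxx11_abi() helper:

import sys
import torch

# False matches wheels tagged cxx11abiFALSE, True matches cxx11abiTRUE
print(torch.compiled_with_cxx11_abi())
# (3, 10) matches wheels tagged cp310
print(sys.version_info[:2])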

5. Test the imports:

from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
from mamba_ssm.ops.triton.selective_state_update import selective_state_update
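
If both imports succeed, you can optionally run a quick functional check of causal_conv1d_fn as well. A minimal sketch, assuming the (batch, dim, seqlen) input layout and (dim, width) weight layout that causal_conv1d_fn expects:

import torch
from causal_conv1d import causal_conv1d_fn

batch, dim, seqlen, width = 2, 16, 64, 4
x = torch.randn(batch, dim, seqlen, device="cuda", dtype=torch.float16)
weight = torch.randn(dim, width, device="cuda", dtype=torch.float16)
bias = torch.randn(dim, device="cuda", dtype=torch.float16)

out = causal_conv1d_fn(x, weight, bias, activation="silu")
print(out.shape)  # expected: torch.Size([2, 16, 64]), same shape as x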

6. Test with a short program
mamba_ssm test:

import torch
from mamba_ssm import Mamba

batch, length, dim = 2, 64, 16
x = torch.randn(batch, length, dim).to("cuda")
model = Mamba(
    # This module uses roughly 3 * expand * d_model^2 parameters
    d_model=dim,  # Model dimension d_model
    d_state=16,  # SSM state expansion factor
    d_conv=4,  # Local convolution width
    expand=2,  # Block expansion factor
).to("cuda")
y = model(x)
assert y.shape == x.shape
print(y)
print(y.shape)

Output:

tensor([[[-0.0035,  0.0065,  0.0156,  ...,  0.0144,  0.0181, -0.0319],
         [ 0.0144, -0.0009, -0.0034,  ...,  0.0666,  0.0301,  0.0096],
         [-0.0102, -0.0006,  0.0264,  ..., -0.0478,  0.0402, -0.0818],
         ...,
         [-0.0177, -0.0073,  0.0390,  ..., -0.0340, -0.0025, -0.0271],
         [ 0.0138,  0.0031, -0.0223,  ..., -0.0278,  0.0072, -0.0143],
         [-0.0076, -0.0186,  0.0079,  ..., -0.0062,  0.0021, -0.0283]],

        [[ 0.0079,  0.0089, -0.0111,  ..., -0.0010,  0.0097,  0.0033],
         [ 0.0176,  0.0140,  0.0087,  ..., -0.0223,  0.0294, -0.0158],
         [-0.0149,  0.0228,  0.0126,  ...,  0.0382, -0.0015, -0.0182],
         ...,
         [ 0.0322, -0.0258,  0.0312,  ..., -0.0045, -0.0036, -0.0509],
         [ 0.0164,  0.0001, -0.0060,  ...,  0.0722,  0.0058,  0.0061],
         [ 0.0365,  0.0176,  0.0470,  ...,  0.0103,  0.0257,  0.0039]]],
       device='cuda:0', grad_fn=<MambaInnerFnBackward>)

torch.Size([2, 64, 16])

Mamba2 test:

from mamba_ssm import Mamba2
import torch

batch, length, dim = 2, 64, 512

x = torch.randn(batch, length, dim).to("cuda")

model = Mamba2(
    # This module uses roughly 3 * expand * d_model^2 parameters
    # make sure d_model * expand / headdim = multiple of 8
    d_model=dim,  # Model dimension d_model
    d_state=64,  # SSM state expansion factor, typically 64 or 128
    d_conv=4,  # Local convolution width
    expand=2,  # Block expansion factor
    headdim=64,  # default 64
).to("cuda")

y = model(x)
assert y.shape == x.shape
print(y)
print(y.shape)

Output:

tensor([[[-0.7895, -0.0058, -0.6374,  ..., -0.1971, -0.1220,  0.9073],
         [-0.2546,  0.3763, -0.6773,  ..., -0.2671,  0.4855, -0.7709],
         [ 0.4188,  0.3607,  0.7131,  ...,  0.9439, -0.0798, -1.2252],
         ...,
         [-1.3136, -1.1002, -0.5330,  ...,  1.5189,  0.2091,  1.0726],
         [-0.4520, -0.6626, -0.3810,  ...,  0.3964,  0.0947,  0.7275],
         [-0.3024, -0.2375,  0.2435,  ...,  0.4073,  0.4688,  0.6197]],

        [[ 0.0257,  0.7625,  0.7594,  ...,  0.3531,  0.3276,  0.4292],
         [-1.0428,  0.8166,  0.1294,  ...,  0.8236,  0.0515, -0.3141],
         [ 0.1267, -0.6214,  0.1667,  ..., -0.4576,  0.7774,  0.7242],
         ...,
         [ 0.0609, -0.0283,  0.4718,  ...,  0.7035,  0.2011,  0.2541],
         [ 0.3708, -0.0039,  0.2280,  ...,  0.9191, -0.6267, -0.2572],
         [-0.2993,  0.0933,  0.2601,  ...,  0.9123, -1.0403,  0.7488]]],
       device='cuda:0', grad_fn=<MambaSplitConv1dScanCombinedFnBackward>)
torch.Size([2, 64, 512])
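
Note that in this Mamba2 configuration d_model * expand / headdim = 512 * 2 / 64 = 16, which satisfies the "multiple of 8" constraint mentioned in the comment.

Both tests above only run the forward pass. To confirm that the fused kernels also work for training, you can additionally run a backward pass. A minimal sketch reusing the Mamba block from step 6 (the loss here is just a placeholder for the gradient check):

import torch
from mamba_ssm import Mamba

model = Mamba(d_model=16, d_state=16, d_conv=4, expand=2).to("cuda")
x = torch.randn(2, 64, 16, device="cuda")

y = model(x)
loss = y.sum()   # placeholder loss, only used to trigger backpropagation
loss.backward()  # should complete without errors if the CUDA kernels work

# every parameter should now have a gradient
print(all(p.grad is not None for p in model.parameters()))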
