mmdet 加载预训练模型多卡训练过程中,存在显卡占用显存不均匀

发布于:2024-12-06 ⋅ 阅读:(25) ⋅ 点赞:(0)

1. 问题描述

基于mmdet https://github.com/open-mmlab/mmdetection代码仓库,修改了自己的检测代码,加载了预训练模型,进行分布式训练。

在训练过程中,出现了显卡的占用显存不均匀的问题。

如图所示,可以看到显卡2 占用了更多的显存。

2. 分析

查找了网上的资料[1,2],发现这个问题在于load模型的时候,直接load的话会导致参数加载到之前保存模型的device上(大部分情况下应该是只用cuda0去保存),这里可以将load函数加一个参数为map_location解决:

model_weights = torch.load(path, map_location='cpu')

3. debug

3.1 代码逻辑

mmengin/runner源代码中,如果没有指定设备那就会按照下面的逻辑得到device,然这种方式得到的device是'cuda'

因为没有指定显卡号,一般是在cuda:0上占用的显存多,参考[1]。

我遇到的问题是,占用显卡更大的显卡不应当是卡0,任意一张卡都有可能出现显存占用更多的显存。我也还没完全搞懂这个随机的逻辑。

DEVICE = 'cpu'
if is_npu_available():
    DEVICE = 'npu'
elif is_cuda_available():
    DEVICE = 'cuda'
elif is_mlu_available():
    DEVICE = 'mlu'
elif is_mps_available():
    DEVICE = 'mps'
elif is_dipu_available():
    DEVICE = 'dipu'


def get_device() -> str:
    """Returns the currently existing device type.

    Returns:
        str: cuda | npu | mlu | mps | cpu.
    """
    return DEVICE

3.2 修改逻辑

既然代码存在指定cuda显卡的不确定性,那么我们干脆直接把模型加载在内存中。具体操作是在传入的参数中指定cpu,这样就不会出现某个显卡占的多了。

3.3 具体的修改如下

把Runner类这个load_or_resume函数修改

(1)

self.resume(resume_from) 

改为

self.resume(resume_from,map_location=torch.device('cpu'))

(2)

self.load_checkpoint(self._load_from) 

改为

self.load_checkpoint(self._load_from, map_location=torch.device('cpu'))

3.4 改完的代码如下所示。

    def load_or_resume(self) -> None:
        """load or resume checkpoint."""
        if self._has_loaded:
            return None

        # decide to load from checkpoint or resume from checkpoint
        resume_from = None
        if self._resume and self._load_from is None:
            # auto resume from the latest checkpoint
            resume_from = find_latest_checkpoint(self.work_dir)
            self.logger.info(
                f'Auto resumed from the latest checkpoint {resume_from}.')
        elif self._resume and self._load_from is not None:
            # resume from the specified checkpoint
            resume_from = self._load_from

        if resume_from is not None:
            ## 2024.12.4
            ## add ,map_location=torch.device('cpu')
            self.resume(resume_from,map_location=torch.device('cpu'))
            # self.resume(resume_from)
            self._has_loaded = True
        elif self._load_from is not None:
            ## add ,map_location=torch.device('cpu')
            self.load_checkpoint(self._load_from, map_location=torch.device('cpu'))
            # self.load_checkpoint(self._load_from)
            self._has_loaded = True

4. 结果

如图所示,问题解决了。

5. 参考


网站公告

今日签到

点亮在社区的每一天
去签到