1. 问题描述
基于mmdet https://github.com/open-mmlab/mmdetection代码仓库,修改了自己的检测代码,加载了预训练模型,进行分布式训练。
在训练过程中,出现了显卡的占用显存不均匀的问题。
如图所示,可以看到显卡2 占用了更多的显存。
2. 分析
查找了网上的资料[1,2],发现这个问题在于load模型的时候,直接load的话会导致参数加载到之前保存模型的device上(大部分情况下应该是只用cuda0去保存),这里可以将load函数加一个参数为map_location解决:
model_weights = torch.load(path, map_location='cpu')
3. debug
3.1 代码逻辑
mmengin/runner源代码中,如果没有指定设备那就会按照下面的逻辑得到device,然这种方式得到的device是'cuda'
因为没有指定显卡号,一般是在cuda:0上占用的显存多,参考[1]。
我遇到的问题是,占用显卡更大的显卡不应当是卡0,任意一张卡都有可能出现显存占用更多的显存。我也还没完全搞懂这个随机的逻辑。
DEVICE = 'cpu'
if is_npu_available():
DEVICE = 'npu'
elif is_cuda_available():
DEVICE = 'cuda'
elif is_mlu_available():
DEVICE = 'mlu'
elif is_mps_available():
DEVICE = 'mps'
elif is_dipu_available():
DEVICE = 'dipu'
def get_device() -> str:
"""Returns the currently existing device type.
Returns:
str: cuda | npu | mlu | mps | cpu.
"""
return DEVICE
3.2 修改逻辑
既然代码存在指定cuda显卡的不确定性,那么我们干脆直接把模型加载在内存中。具体操作是在传入的参数中指定cpu,这样就不会出现某个显卡占的多了。
3.3 具体的修改如下
把Runner类这个load_or_resume函数修改
(1)
self.resume(resume_from)
改为
self.resume(resume_from,map_location=torch.device('cpu'))
(2)
self.load_checkpoint(self._load_from)
改为
self.load_checkpoint(self._load_from, map_location=torch.device('cpu'))
3.4 改完的代码如下所示。
def load_or_resume(self) -> None:
"""load or resume checkpoint."""
if self._has_loaded:
return None
# decide to load from checkpoint or resume from checkpoint
resume_from = None
if self._resume and self._load_from is None:
# auto resume from the latest checkpoint
resume_from = find_latest_checkpoint(self.work_dir)
self.logger.info(
f'Auto resumed from the latest checkpoint {resume_from}.')
elif self._resume and self._load_from is not None:
# resume from the specified checkpoint
resume_from = self._load_from
if resume_from is not None:
## 2024.12.4
## add ,map_location=torch.device('cpu')
self.resume(resume_from,map_location=torch.device('cpu'))
# self.resume(resume_from)
self._has_loaded = True
elif self._load_from is not None:
## add ,map_location=torch.device('cpu')
self.load_checkpoint(self._load_from, map_location=torch.device('cpu'))
# self.load_checkpoint(self._load_from)
self._has_loaded = True
4. 结果
如图所示,问题解决了。