Windows 系统下,使用 PyTorch 的 DataLoader 时,如果 num_workers 参数设置为大于 0 的值,报错

发布于:2025-02-23 ⋅ 阅读:(19) ⋅ 点赞:(0)

在 Windows 系统下,使用 PyTorch 的 DataLoader 时,如果 num_workers 参数设置为大于 0 的值,可能会遇到以下错误:

RuntimeError: 
        An attempt has been made to start a new process before the
        current process has finished its bootstrapping phase.

        This probably means that you are not using fork to start your
        child processes and you have forgotten to use the proper idiom
        in the main module:

            if __name__ == '__main__':
                freeze_support()
                ...

        The "freeze_support()" line can be omitted if the program
        is not going to be frozen to produce an executable.

原因分析

这个错误是由于 Windows 系统不支持 fork 方式启动子进程,而 PyTorch 的 DataLoader 在多线程情况下默认使用 fork。因此,当 num_workers 大于 0 时,会触发这个错误。

解决方案

  1. num_workers 设置为 0 在 Windows 系统下,建议将 num_workers 设置为 0,这样 DataLoader 将不会使用额外的工作进程来加载数据,从而避免上述错误。代码如下:

    dataloader = DataLoader(dataset, batch_size=8, shuffle=True, num_workers=0)
  2. 使用 spawnforkserver 启动方式 如果需要使用多线程加载数据,可以指定 multiprocessing 的启动方式为 spawnforkserver。在代码的开头添加以下内容:

    import multiprocessing
    multiprocessing.set_start_method('spawn', force=True)

    然后再设置 num_workers 为大于 0 的值:

    dataloader = DataLoader(dataset, batch_size=8, shuffle=True, num_workers=4)
  3. 确保 if __name__ == '__main__': 保护 确保主程序入口被 if __name__ == '__main__': 保护,这样可以避免多进程启动时的冲突。示例如下:

  4. if __name__ == '__main__':
        import multiprocessing
        multiprocessing.set_start_method('spawn', force=True)
    
        # Your main code here
        dataloader = DataLoader(dataset, batch_size=8, shuffle=True, num_workers=4)

推荐解决方案

在 Windows 系统下,最简单的解决方案是将 num_workers 设置为 0。如果需要使用多线程加载数据,可以尝试指定 multiprocessing 的启动方式为 spawnforkserver,并确保主程序入口被 if __name__ == '__main__': 保护。