为什么我们需要if __name__ == __main__:

发布于:2025-04-15 ⋅ 阅读:(33) ⋅ 点赞:(0)
[目录]
0.前言
1.什么是 `__name__`?
2.`if __name__ == '__main__'`: 的作用
3.为何Windows更需`if __name__ =`?

前言

if __name__ == '__main__': 是 Python 中一个非常重要的惯用法,尤其在使用 multiprocessing 模块或编写可导入的模块时。它的作用是区分脚本是直接运行还是被导入,从而控制代码的执行行为。很多初学者可能对此感到困惑,不明白其真正的用途和重要性。

直到发现自己的CPU因为不良的代码习惯而干烧了数十分钟才知道后悔

下面详细解释它的作用和工作原理。


1. 什么是 __name__

__name__ 是 Python 中的一个内置变量,它的值取决于脚本的运行方式:

  • 当脚本被直接运行时(例如通过 python script.py 运行):

    • __name__ 的值是 '__main__'
    • 这是 Python 解释器自动设置的,表示当前脚本是“主脚本”。
  • 当脚本被导入为模块时(例如 import script):

    • __name__ 的值是模块的名称(例如 script)。
    • 此时,脚本中的代码会被执行,但 __name__ 不再是 '__main__'

示例

假设有一个脚本 example.py

print("The value of __name__ is:", __name__)

if __name__ == '__main__':
    print("This script is being run directly.")
else:
    print("This script is being imported as a module.")
  • 直接运行

    python example.py
    

    输出:

    The value of __name__ is: __main__
    This script is being run directly.
    
  • 作为模块导入: 创建另一个脚本 importer.py

    import example
    

    没错,这个新脚本就这么短小精悍。

    输出:

    The value of __name__ is: example
    This script is being imported as a module.
    

2.if __name__ == '__main__': 的作用

if __name__ == '__main__': 的作用是让某些代码块只在脚本被直接运行时执行,而在脚本被导入时不执行。这有以下几个重要用途:

避免导入时的副作用

当一个 Python 脚本被导入为模块时,脚本中的所有顶层代码(不在函数或类中的代码)都会被执行。
如果这些顶层代码包含不希望在导入时运行的操作(例如启动服务器、执行复杂计算、创建进程等),会产生意外行为。
使用 if name == ‘main’:,可以确保这些代码只在脚本被直接运行时执行。

示例

假设有一个脚本 math_utils.py:

# 顶层代码
print("This will always run when the script is imported!")

def add(a, b):
    return a + b

# 不使用 if __name__ == '__main__':
result = add(2, 3)
print(f"Result of add(2, 3): {result}")

另一个脚本 main.py导入了math_utils:

import math_utils

print("Using math_utils to add numbers...")
print(math_utils.add(5, 6))

运行main.py:

This will always run when the script is imported!
Result of add(2, 3): 5
Using math_utils to add numbers...
11

问题在于,math_utils.py 中的顶层代码(printresult = add(2, 3))在导入时被执行了,这可能不是我们想要的。
现在使用 if __name__ == '__main__': 修改 math_utils.py:

print("This will always run when the script is imported!")

def add(a, b):
    return a + b

if __name__ == '__main__':
    result = add(2, 3)
    print(f"Result of add(2, 3): {result}")

这时我们得到的运行后结果为:

This will always run when the script is imported!
Using math_utils to add numbers...
11

我们可以发现,if __name__ == '__main__': 块中的代码(result = add(2, 3) 和相关的 print)只在 math_utils.py 被直接运行时执行,导入时不会运行。

其他用途

(1) 测试代码

你可以在 if name == ‘main’: 中添加测试代码,这些代码只在脚本直接运行时执行,而不会在导入时运行。譬如:

def add(a, b):
    return a + b

if __name__ == '__main__':
    # 测试代码
    print(add(2, 3))
    print(add(5, 6))

直接运行时,测试代码会执行;而导入时,测试代码则不会运行。

(2) 命令行工具

许多命令行工具使用 if __name__ == '__main__': 来定义入口点,确保主逻辑只在直接运行时执行。譬如:

import sys

def main():
    print("Hello, world!")
    print("Arguments:", sys.argv)

if __name__ == '__main__':
    main()

如果不使用 if __name__ == '__main__': 保护,主逻辑会在脚本被导入时意外执行,这可能导致不希望的行为,尤其是在命令行工具中。

3. 为何Windows更需if __name__ =

这主要取决于我们 coding \texttt{coding} coding的场景——是否会用到multiprocessing 或是其它类似方法来加速我们的计算。

在 Linux 上,multiprocessing 默认使用 fork 方法创建子进程。fork 会直接复制主进程的内存状态,子进程不会重新加载脚本,因此顶层代码不会被重复执行。

在 Windows 上,multiprocessing 使用 spawn 方法,必须重新加载脚本,导致顶层代码被重复执行,因此需要 if __name__ == '__main__': 保护。

值得提醒的是,当我们在 Github \texttt{Github} Github上拿到其它大佬的项目代码时,我们自己在本地运行时一定要检查是否存在类似问题。以pix2pix与CycleGAN项目为例,其原始代码为:

import os
import numpy as np
import cv2
import argparse
from multiprocessing import Pool


def image_write(path_A, path_B, path_AB):
  im_A = cv2.imread(path_A, 1) # python2: cv2.CV_LOAD_IMAGE_COLOR; python3: cv2.IMREAD_COLOR
  im_B = cv2.imread(path_B, 1) # python2: cv2.CV_LOAD_IMAGE_COLOR; python3: cv2.IMREAD_COLOR
  im_AB = np.concatenate([im_A, im_B], 1)
  cv2.imwrite(path_AB, im_AB)


parser = argparse.ArgumentParser('create image pairs')
parser.add_argument('--fold_A', dest='fold_A', help='input directory for image A', type=str, default='../dataset/50kshoes_edges')
parser.add_argument('--fold_B', dest='fold_B', help='input directory for image B', type=str, default='../dataset/50kshoes_jpg')
parser.add_argument('--fold_AB', dest='fold_AB', help='output directory', type=str, default='../dataset/test_AB')
parser.add_argument('--num_imgs', dest='num_imgs', help='number of images', type=int, default=1000000)
parser.add_argument('--use_AB', dest='use_AB', help='if true: (0001_A, 0001_B) to (0001_AB)', action='store_true')
parser.add_argument('--no_multiprocessing', dest='no_multiprocessing', help='If used, chooses single CPU execution instead of parallel execution', action='store_true',default=False)
args = parser.parse_args()

for arg in vars(args):
  print('[%s] = ' % arg, getattr(args, arg))

splits = os.listdir(args.fold_A)

if not args.no_multiprocessing:
  pool=Pool()

for sp in splits:
  img_fold_A = os.path.join(args.fold_A, sp)
  img_fold_B = os.path.join(args.fold_B, sp)
  img_list = os.listdir(img_fold_A)
  if args.use_AB:
      img_list = [img_path for img_path in img_list if '_A.' in img_path]

  num_imgs = min(args.num_imgs, len(img_list))
  print('split = %s, use %d/%d images' % (sp, num_imgs, len(img_list)))
  img_fold_AB = os.path.join(args.fold_AB, sp)
  if not os.path.isdir(img_fold_AB):
      os.makedirs(img_fold_AB)
  print('split = %s, number of images = %d' % (sp, num_imgs))
  for n in range(num_imgs):
      name_A = img_list[n]
      path_A = os.path.join(img_fold_A, name_A)
      if args.use_AB:
          name_B = name_A.replace('_A.', '_B.')
      else:
          name_B = name_A
      path_B = os.path.join(img_fold_B, name_B)
      if os.path.isfile(path_A) and os.path.isfile(path_B):
          name_AB = name_A
          if args.use_AB:
              name_AB = name_AB.replace('_A.', '.')  # remove _A
          path_AB = os.path.join(img_fold_AB, name_AB)
          if not args.no_multiprocessing:
              pool.apply_async(image_write, args=(path_A, path_B, path_AB))
          else:
              im_A = cv2.imread(path_A, 1) # python2: cv2.CV_LOAD_IMAGE_COLOR; python3: cv2.IMREAD_COLOR
              im_B = cv2.imread(path_B, 1) # python2: cv2.CV_LOAD_IMAGE_COLOR; python3: cv2.IMREAD_COLOR
              im_AB = np.concatenate([im_A, im_B], 1)
              cv2.imwrite(path_AB, im_AB)
if not args.no_multiprocessing:
  pool.close()
  pool.join()

即是显然是在 Linux \texttt{Linux} Linux系统上使用的multiprocessing方法。本人一开始尚未注意到该问题,结果CPU干烧了十几分钟,出现类似的 RuntimeError \texttt{RuntimeError} RuntimeError

将代码修正后:

import os
import numpy as np
import cv2
import argparse
from multiprocessing import Pool

def image_write(path_A, path_B, path_AB):
    im_A = cv2.imread(path_A, 1)
    im_B = cv2.imread(path_B, 1)
    if im_A is None or im_B is None:
        print(f"Failed to load images: {path_A} or {path_B}")
        return
    im_AB = np.concatenate([im_A, im_B], 1)
    cv2.imwrite(path_AB, im_AB)

if __name__ == '__main__':
    parser = argparse.ArgumentParser('create image pairs')
    parser.add_argument('--fold_A', dest='fold_A', help='input directory for image A', type=str, default='../dataset/50kshoes_edges')
    parser.add_argument('--fold_B', dest='fold_B', help='input directory for image B', type=str, default='../dataset/50kshoes_jpg')
    parser.add_argument('--fold_AB', dest='fold_AB', help='output directory', type=str, default='../dataset/test_AB')
    parser.add_argument('--num_imgs', dest='num_imgs', help='number of images', type=int, default=1000000)
    parser.add_argument('--use_AB', dest='use_AB', help='if true: (0001_A, 0001_B) to (0001_AB)', action='store_true')
    parser.add_argument('--no_multiprocessing', dest='no_multiprocessing', help='If used, chooses single CPU execution instead of parallel execution', action='store_true', default=False)
    args = parser.parse_args()

    for arg in vars(args):
        print('[%s] = ' % arg, getattr(args, arg))

    splits = os.listdir(args.fold_A)

    if not args.no_multiprocessing:
        pool = Pool()

    for sp in splits:
        img_fold_A = os.path.join(args.fold_A, sp)
        img_fold_B = os.path.join(args.fold_B, sp)
        img_list = os.listdir(img_fold_A)
        if args.use_AB:
            img_list = [img_path for img_path in img_list if '_A.' in img_path]

        num_imgs = min(args.num_imgs, len(img_list))
        print('split = %s, use %d/%d images' % (sp, num_imgs, len(img_list)))
        img_fold_AB = os.path.join(args.fold_AB, sp)
        if not os.path.isdir(img_fold_AB):
            os.makedirs(img_fold_AB)
        print('split = %s, number of images = %d' % (sp, num_imgs))
        for n in range(num_imgs):
            name_A = img_list[n]
            path_A = os.path.join(img_fold_A, name_A)
            if args.use_AB:
                name_B = name_A.replace('_A.', '_B.')
            else:
                name_B = name_A
            path_B = os.path.join(img_fold_B, name_B)
            if os.path.isfile(path_A) and os.path.isfile(path_B):
                print(f"Found pair: {path_A} and {path_B}")
                name_AB = name_A
                if args.use_AB:
                    name_AB = name_AB.replace('_A.', '.')  # remove _A
                path_AB = os.path.join(img_fold_AB, name_AB)
                if not args.no_multiprocessing:
                    pool.apply_async(image_write, args=(path_A, path_B, path_AB))
                else:
                    image_write(path_A, path_B, path_AB)
            else:
                print(f"Pair not found: {path_A} or {path_B}")

    if not args.no_multiprocessing:
        pool.close()
        pool.join()

终于能够正常运行。