pytorch 设置参数
前言
深度学习的pytorch框架学习,有错误的地方请大家批评指正
一、openAI的官方代码
def create_argparser():
defaults = dict(
data_dir="",
schedule_sampler="uniform",
lr=1e-4,
weight_decay=0.0,
lr_anneal_steps=0,
batch_size=1,
microbatch=-1, # -1 disables microbatches
ema_rate="0.9999", # comma-separated list of EMA values
log_interval=10,
save_interval=10000,
resume_checkpoint="",
use_fp16=False,
fp16_scale_growth=1e-3,
)
defaults.update(model_and_diffusion_defaults())
parser = argparse.ArgumentParser()
add_dict_to_argparser(parser, defaults)
return parser
二、解析
1.使用字典,简化添加参数过程
parser = argparse.ArgumentParser()
add_dict_to_argparser(parser, defaults)
2. add_dict_to_argparser函数
代码如下(示例):
def add_dict_to_argparser(parser, default_dict):
for k, v in default_dict.items():
v_type = type(v)
if v is None:
v_type = str
elif isinstance(v, bool):
v_type = str2bool
parser.add_argument(f"--{k}", default=v, type=v_type)
2. add_dict_to_argparser函数
parser = argparse.ArgumentParser()
parser.add_argument('--batch_size', default=128, type=int,help='Size of a training mini-batch.')
总结
使用新办法可以快速添加参数,很方便,而且看起来很简洁美观
本文含有隐藏内容,请 开通VIP 后查看