在上一篇文章中,我们探讨了联邦学习与隐私保护技术。本文将深入介绍神经架构搜索(Neural Architecture Search, NAS)这一自动化机器学习方法,它能够自动设计高性能的神经网络架构。我们将使用PyTorch实现基于梯度优化的DARTS方法,并在CIFAR-10数据集上进行验证。
一、神经架构搜索基础
神经架构搜索是AutoML的核心技术之一,旨在自动化神经网络设计过程。
1. NAS的核心组件
组件 | 描述 | 典型实现 |
---|---|---|
搜索空间 | 定义可能架构的集合 | 细胞结构、宏架构 |
搜索策略 | 探索搜索空间的方法 | 强化学习、进化算法、梯度优化 |
性能评估 | 评估架构质量的方式 | 代理指标、权重共享 |
2. 主流NAS方法对比
from enum import Enum


class NASMethod(Enum):
    """Families of neural architecture search strategies compared in this article."""

    RL_BASED = "基于强化学习"  # reinforcement learning; Google's early approach
    EVOLUTIONARY = "进化算法"  # evolutionary search; proposed by Google Brain
    GRADIENT_BASED = "梯度优化"  # gradient-based optimisation, DARTS being the representative
    ONESHOT = "权重共享"  # weight sharing, e.g. ENAS and ProxylessNAS
3. DARTS数学原理
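DARTS(Differentiable ARchiTecture Search)将离散架构搜索转化为连续优化问题:对每条边 $(i,j)$ 上的候选操作集合 $\mathcal{O}$ 做 softmax 松弛,得到混合操作

$$\bar{o}^{(i,j)}(x)=\sum_{o\in\mathcal{O}}\frac{\exp\left(\alpha_o^{(i,j)}\right)}{\sum_{o'\in\mathcal{O}}\exp\left(\alpha_{o'}^{(i,j)}\right)}\,o(x)$$

其中 $\alpha$ 为架构参数,$w$ 为网络权重。搜索过程是一个双层优化问题:

$$\min_{\alpha}\ \mathcal{L}_{val}\left(w^{*}(\alpha),\alpha\right)\quad \text{s.t.}\quad w^{*}(\alpha)=\arg\min_{w}\ \mathcal{L}_{train}(w,\alpha)$$

搜索结束后,每条边保留权重最大的操作,即 $o^{(i,j)}=\arg\max_{o\in\mathcal{O}}\alpha_o^{(i,j)}$,从而得到离散架构。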
二、DARTS实战:CIFAR-10图像分类
1. 环境配置
pip install torch torchvision matplotlib graphviz
2. 实现可微分架构搜索
2.1 搜索空间定义
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import numpy as np
from matplotlib import pyplot as plt
import copy
from graphviz import Digraph
# Device selection: use CUDA when it is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"使用设备: {device}")
# Candidate operations of the search space.
# Each entry maps an operation name to a factory ``(C, stride) -> nn.Module``;
# every factory keeps the channel count C unchanged.
OPS = {
    # Zero operation: outputs zeros (sub-sampled when stride != 1).
    'none': lambda C, stride: Zero(stride),
    # Identity when the resolution is kept, otherwise a factorized reduction.
    'skip_connect': lambda C, stride: (
        FactorizedReduce(C, C) if stride != 1 else Identity()
    ),
    # Plain convolutions (conv -> BN -> ReLU).
    'conv_3x3': lambda C, stride: ConvBNReLU(
        C, C, kernel_size=3, stride=stride, padding=1),
    'conv_5x5': lambda C, stride: ConvBNReLU(
        C, C, kernel_size=5, stride=stride, padding=2),
    # Dilated convolutions with dilation 2.
    'dil_conv_3x3': lambda C, stride: DilConv(
        C, C, kernel_size=3, stride=stride, padding=2, dilation=2),
    'dil_conv_5x5': lambda C, stride: DilConv(
        C, C, kernel_size=5, stride=stride, padding=4, dilation=2),
    # 3x3 pooling followed by batch normalization.
    'max_pool_3x3': lambda C, stride: PoolBN(
        'max', C, kernel_size=3, stride=stride, padding=1),
    'avg_pool_3x3': lambda C, stride: PoolBN(
        'avg', C, kernel_size=3, stride=stride, padding=1),
}
# Basic operation modules
class ConvBNReLU(nn.Module):
    """Convolution followed by batch normalization and a ReLU activation.

    Args:
        C_in: Number of input channels.
        C_out: Number of output channels.
        kernel_size: Kernel size of the convolution.
        stride: Stride of the convolution.
        padding: Padding passed to the convolution.
    """

    def __init__(self, C_in, C_out, kernel_size, stride, padding):
        super().__init__()
        layers = [
            # The convolution carries no bias term; BatchNorm follows it.
            nn.Conv2d(C_in, C_out, kernel_size=kernel_size, stride=stride,
                      padding=padding, bias=False),
            nn.BatchNorm2d(C_out),
            nn.ReLU(inplace=False),
        ]
        self.op = nn.Sequential(*layers)

    def forward(self, x):
        """Apply conv -> BN -> ReLU to ``x`` and return the result."""
        return self.op(x)
class DilConv(nn.Module):
    """Dilated depthwise convolution, 1x1 pointwise convolution, BN, then ReLU.

    Args:
        C_in: Number of input channels (also the group count of the first conv).
        C_out: Number of output channels produced by the 1x1 convolution.
        kernel_size: Kernel size of the dilated convolution.
        stride: Stride of the dilated convolution.
        padding: Padding of the dilated convolution.
        dilation: Dilation rate of the dilated convolution.
    """

    def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation):
        super().__init__()
        # groups=C_in makes the first convolution act on each channel separately.
        depthwise = nn.Conv2d(
            C_in, C_in, kernel_size=kernel_size, stride=stride,
            padding=padding, dilation=dilation, groups=C_in, bias=False)
        # The 1x1 convolution mixes channels and sets the output width.
        pointwise = nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False)
        self.op = nn.Sequential(
            depthwise,
            pointwise,
            nn.BatchNorm2d(C_out),
            nn.ReLU(inplace=False),
        )

    def forward(self, x):
        """Run ``x`` through the dilated-conv stack."""
        return self.op(x)
class PoolBN(nn.Module):
    """2-D pooling (max or average) followed by batch normalization.

    Args:
        pool_type: Either ``'max'`` or ``'avg'``.
        C: Number of channels (pooling does not change it).
        kernel_size: Pooling window size.
        stride: Pooling stride.
        padding: Pooling padding.

    Raises:
        ValueError: If ``pool_type`` is neither ``'max'`` nor ``'avg'``.
    """

    def __init__(self, pool_type, C, kernel_size, stride, padding):
        super().__init__()
        if pool_type == 'max':
            self.pool = nn.MaxPool2d(kernel_size, stride, padding)
        elif pool_type == 'avg':
            self.pool = nn.AvgPool2d(kernel_size, stride, padding)
        else:
            # Name the offending value so a typo in the search space is easy to spot.
            raise ValueError(
                f"unsupported pool_type {pool_type!r}; expected 'max' or 'avg'")
        self.bn = nn.BatchNorm2d(C)

    def forward(self, x):
        """Pool ``x`` and normalize the pooled result."""
        return self.bn(self.pool(x))
class Identity(nn.Module):
    """Pass-through operation that hands its input back unchanged."""

    def forward(self, x):
        """Return ``x`` itself; no copy is made."""
        return x
class Zero(nn.Module):
    """Operation whose output is the input multiplied by zero.

    Args:
        stride: Spatial step; when it is not 1 the last two dimensions are
            sub-sampled with this step before being zeroed.
    """

    def __init__(self, stride):
        super().__init__()
        self.stride = stride

    def forward(self, x):
        """Return zeros shaped like ``x`` (sub-sampled when ``stride != 1``)."""
        if self.stride != 1:
            # Keep every stride-th row and column of the last two dimensions.
            x = x[:, :, ::self.stride, ::self.stride]
        return x.mul(0.)
class FactorizedReduce(nn.Module):
    """Halve the spatial resolution with two offset stride-2 1x1 convolutions.

    The first branch sees ``x``; the second sees ``x`` shifted by one pixel in
    both spatial dimensions. Their outputs are concatenated along the channel
    axis and batch-normalized.

    Args:
        C_in: Number of input channels.
        C_out: Total number of output channels after concatenation.

    NOTE: the two branches only produce matching spatial sizes when the input
    height and width are even; odd sizes make the concatenation fail.
    """

    def __init__(self, C_in, C_out):
        super().__init__()
        first_half = C_out // 2
        self.conv1 = nn.Conv2d(C_in, first_half, 1, stride=2, padding=0, bias=False)
        # The second branch takes the remainder so both branches always add up
        # to C_out, including when C_out is odd (identical split for even C_out).
        self.conv2 = nn.Conv2d(C_in, C_out - first_half, 1, stride=2, padding=0, bias=False)
        self.bn = nn.BatchNorm2d(C_out)

    def forward(self, x):
        """Return the reduced feature map with ``C_out`` channels."""
        shifted = x[:, :, 1:, 1:]
        merged = torch.cat([self.conv1(x), self.conv2(shifted)], dim=1)
        return self.bn(merged)
2.2 可微分细胞结构实现
class MixedOp(nn.Module):
    """Weighted mixture of every candidate operation registered in ``OPS``.

    Args:
        C: Channel count handed to each operation factory.
        stride: Stride handed to each operation factory.
    """

    def __init__(self, C, stride):
        super().__init__()
        # One instance of every candidate, in OPS insertion order.
        self._ops = nn.ModuleList([build(C, stride) for build in OPS.values()])

    def forward(self, x, weights):
        """Return the sum of ``weights[k] * op_k(x)`` over all candidates.

        ``weights`` is paired with the operations positionally via ``zip``.
        """
        mixed = 0
        for weight, op in zip(weights, self._ops):
            mixed = mixed + weight * op(x)
        return mixed
class Cell(nn.Module):
    """Differentiable search cell: a DAG whose edges are ``MixedOp`` modules.

    Args:
        steps: Number of intermediate nodes in the cell.
        multiplier: How many of the last states are concatenated as output.
        C_prev_prev: Channel count of the first input (``s0``).
        C_prev: Channel count of the second input (``s1``).
        C: Channel count used inside the cell.
        reduction: Whether this cell uses stride 2 on edges leaving its inputs.
        reduction_prev: Whether the previous cell was a reduction cell.
    """

    def __init__(self, steps, multiplier, C_prev_prev, C_prev, C, reduction, reduction_prev):
        super().__init__()
        self.reduction = reduction
        self.steps = steps
        self.multiplier = multiplier

        # Input preprocessing: after a reduction cell s0 is spatially reduced
        # with FactorizedReduce; otherwise a 1x1 conv maps it to C channels.
        self.preprocess0 = (
            FactorizedReduce(C_prev_prev, C)
            if reduction_prev
            else ConvBNReLU(C_prev_prev, C, 1, 1, 0)
        )
        self.preprocess1 = ConvBNReLU(C_prev, C, 1, 1, 0)

        # Build the DAG: intermediate node i receives one edge from each of
        # the two cell inputs and from every earlier intermediate node.
        self._ops = nn.ModuleList()
        self._bns = nn.ModuleList()  # created but never populated in this class
        for node in range(self.steps):
            num_inputs = 2 + node
            for src in range(num_inputs):
                # Only edges from the two cell inputs are strided in a reduction cell.
                stride = 2 if (reduction and src < 2) else 1
                self._ops.append(MixedOp(C, stride))

    def forward(self, s0, s1, weights):
        """Compute the cell output.

        Args:
            s0: Output of the cell two positions back.
            s1: Output of the previous cell.
            weights: Per-edge operation weights, indexed in the same order in
                which the edges were created.

        Returns:
            Channel-wise concatenation of the last ``multiplier`` states.
        """
        states = [self.preprocess0(s0), self.preprocess1(s1)]
        edge = 0
        for _ in range(self.steps):
            node_out = 0
            for h in states:
                node_out = node_out + self._ops[edge](h, weights[edge])
                edge += 1
            states.append(node_out)
        return torch.cat(states[-self.multiplier:], dim=1)
2.3 完整搜索网络
class Network(nn.Module):
    """Differentiable architecture-search network.

    A convolutional stem feeds a stack of ``layers`` search cells, followed by
    global average pooling and a linear classifier. All cells share a single
    matrix of architecture parameters ``_alphas`` with one row per cell edge
    and one column per candidate operation in ``OPS``.

    NOTE: ``_alphas`` is an ``nn.Parameter``, so it is also returned by
    ``parameters()``. Code that builds an optimizer for the network weights
    must filter it out if the architecture parameters are to be trained
    separately.

    Args:
        C: Base channel count of the cells.
        num_classes: Number of output classes.
        layers: Number of stacked cells.
        criterion: Loss callable applied to ``(logits, target)``.
        steps: Number of intermediate nodes per cell.
        multiplier: Number of cell states concatenated as the cell output.
        stem_multiplier: The stem outputs ``stem_multiplier * C`` channels.
    """

    def __init__(self, C, num_classes, layers, criterion, steps=4, multiplier=4, stem_multiplier=3):
        super().__init__()
        self._C = C
        self._num_classes = num_classes
        self._layers = layers
        self._criterion = criterion
        self._steps = steps
        self._multiplier = multiplier

        C_curr = stem_multiplier * C
        self.stem = nn.Sequential(
            nn.Conv2d(3, C_curr, 3, padding=1, bias=False),
            nn.BatchNorm2d(C_curr),
        )

        C_prev_prev, C_prev, C_curr = C_curr, C_curr, C
        self.cells = nn.ModuleList()
        reduction_prev = False
        for i in range(layers):
            # Cells at one third and two thirds of the depth are reduction
            # cells; the channel count doubles there.
            reduction = i in (layers // 3, 2 * layers // 3)
            if reduction:
                C_curr *= 2
            cell = Cell(steps, multiplier, C_prev_prev, C_prev, C_curr, reduction, reduction_prev)
            reduction_prev = reduction
            self.cells.append(cell)
            C_prev_prev, C_prev = C_prev, multiplier * C_curr

        self.global_pooling = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Linear(C_prev, num_classes)

        # Architecture parameters: one row per edge, one column per operation,
        # initialised with small random values.
        k = sum(2 + i for i in range(steps))
        num_ops = len(OPS)
        self._alphas = nn.Parameter(1e-3 * torch.randn(k, num_ops))
        # Optimizer over the architecture parameters only. Kept for backward
        # compatibility; nothing else in this class uses it.
        self._arch_optimizer = torch.optim.Adam([self._alphas], lr=6e-4, betas=(0.5, 0.999))

    def forward(self, x):
        """Return class logits for the image batch ``x``."""
        s0 = s1 = self.stem(x)
        # Softmax over the candidate operations of every edge.
        weights = F.softmax(self._alphas, dim=-1)
        for cell in self.cells:
            s0, s1 = s1, cell(s0, s1, weights)
        out = self.global_pooling(s1)
        logits = self.classifier(out.view(out.size(0), -1))
        return logits

    def _loss(self, input, target):
        """Return the criterion loss plus a penalty on the architecture parameters."""
        logits = self(input)
        # Penalty 0.01 * sum(exp(-alpha)): it shrinks as the alphas grow.
        # (This is an exponential penalty, not an L1 norm.)
        reg_loss = 0.01 * torch.sum(torch.exp(-self._alphas))
        return self._criterion(logits, target) + reg_loss

    def arch_parameters(self):
        """Return the architecture parameters as a one-element list."""
        return [self._alphas]

    def genotype(self):
        """Derive the discrete architecture from the architecture parameters.

        Returns:
            A list with one entry per intermediate node. Entry ``i`` is a list
            of ``(op_name, j)`` tuples, one per incoming edge ``j``, where
            ``op_name`` is the candidate with the largest softmax weight on
            that edge (the first one on ties).
        """
        op_names = list(OPS)  # built once instead of once per edge
        probs = F.softmax(self._alphas, dim=-1).detach().cpu().numpy()

        gene = []
        start = 0
        for i in range(self._steps):
            end = start + i + 2
            # Rows start:end hold the edges entering intermediate node i.
            best = probs[start:end].argmax(axis=1)
            gene.append([(op_names[int(k)], j) for j, k in enumerate(best)])
            start = end
        return gene

    def plot_genotype(self, filename, genotype=None, view=True):
        """Render a genotype as a PNG graph with graphviz.

        Args:
            filename: Output path passed to ``Digraph.render``.
            genotype: Genotype to draw, in the format returned by
                ``genotype()``. Defaults to the model's current genotype.
            view: Whether to open the rendered file in the system viewer.
                Pass ``False`` on machines without an image viewer.
        """
        if genotype is None:
            genotype = self.genotype()
        dot = Digraph(format='png')
        for i, edges in enumerate(genotype):
            for op, j in edges:
                # Intermediate node i is drawn as graph node i + 2; nodes 0
                # and 1 are the two cell inputs.
                dot.edge(str(j), str(i + 2), label=op)
        dot.node("0", fillcolor='lightblue', style='filled')
        dot.node("1", fillcolor='lightblue', style='filled')
        dot.render(filename, view=view)
3. 搜索算法实现
class DARTS:
    """Architecture search driver.

    Each training iteration alternates one Adam step on the architecture
    parameters, computed on a validation batch, with one SGD step on the
    network weights, computed on the training batch.

    Args:
        model: Search network exposing ``arch_parameters()``, ``_loss()`` and
            ``genotype()``.
        train_loader: Loader used for the weight updates.
        val_loader: Loader used for the architecture updates and validation.
        epochs: Number of search epochs (also the cosine schedule length).
    """

    def __init__(self, model, train_loader, val_loader, epochs=50):
        self.model = model.to(device)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.epochs = epochs

        # The architecture parameters are registered on the model, so they
        # show up in model.parameters(). Keep them out of the weight optimizer;
        # otherwise they would also receive SGD updates, weight decay and
        # gradient clipping on top of their own Adam updates.
        arch_params = list(self.model.arch_parameters())
        arch_ids = {id(p) for p in arch_params}
        self._weight_params = [
            p for p in self.model.parameters() if id(p) not in arch_ids
        ]

        # Weight optimizer with cosine-annealed learning rate
        self.optimizer = torch.optim.SGD(
            self._weight_params, lr=0.025, momentum=0.9, weight_decay=3e-4)
        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            self.optimizer, epochs, eta_min=0.001)
        # Architecture optimizer
        self.arch_optimizer = torch.optim.Adam(
            arch_params, lr=3e-4, betas=(0.5, 0.999))

    def _arch_batches(self):
        """Yield validation batches indefinitely, restarting the loader when exhausted."""
        while True:
            produced = False
            for batch in self.val_loader:
                produced = True
                yield batch
            if not produced:
                # Empty loader: stop instead of spinning forever.
                return

    def _train(self):
        """Run one search epoch; return ``(mean train loss, train accuracy %)``."""
        self.model.train()
        train_loss = 0.0
        correct = 0
        total = 0
        arch_batches = self._arch_batches()
        for inputs, targets in self.train_loader:
            inputs, targets = inputs.to(device), targets.to(device)

            # Architecture step on held-out data so that the architecture is
            # chosen for generalisation rather than for fitting the training
            # batch. Falls back to the training batch if the validation
            # loader yields nothing.
            arch_batch = next(arch_batches, None)
            if arch_batch is None:
                arch_inputs, arch_targets = inputs, targets
            else:
                arch_inputs = arch_batch[0].to(device)
                arch_targets = arch_batch[1].to(device)
            self.arch_optimizer.zero_grad()
            arch_loss = self.model._loss(arch_inputs, arch_targets)
            arch_loss.backward()
            self.arch_optimizer.step()

            # Weight step on the training batch
            self.optimizer.zero_grad()
            loss = self.model._loss(inputs, targets)
            loss.backward()
            nn.utils.clip_grad_norm_(self._weight_params, 5.0)
            self.optimizer.step()

            train_loss += loss.item()
            # Accuracy bookkeeping needs no autograd graph.
            with torch.no_grad():
                predicted = self.model(inputs).argmax(dim=1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
        # Guard the divisions so an empty loader reports zeros instead of crashing.
        num_batches = max(len(self.train_loader), 1)
        return train_loss / num_batches, 100. * correct / max(total, 1)

    def _validate(self):
        """Evaluate on the validation loader; return ``(mean loss, accuracy %)``."""
        self.model.eval()
        val_loss = 0.0
        correct = 0
        total = 0
        with torch.no_grad():
            for inputs, targets in self.val_loader:
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = self.model(inputs)
                loss = self.model._loss(inputs, targets)
                val_loss += loss.item()
                _, predicted = outputs.max(1)
                total += targets.size(0)
                correct += predicted.eq(targets).sum().item()
        num_batches = max(len(self.val_loader), 1)
        return val_loss / num_batches, 100. * correct / max(total, 1)

    def search(self):
        """Run the search for ``self.epochs`` epochs.

        Returns:
            ``(best_genotype, history)`` where ``best_genotype`` is the
            genotype at the epoch with the highest validation accuracy (the
            initial genotype if no epoch improves on 0) and ``history`` maps
            'train_loss', 'val_loss', 'train_acc' and 'val_acc' to per-epoch
            lists.
        """
        best_acc = 0
        # Start from the initial genotype so the return value is always
        # defined, even when epochs == 0 or accuracy never rises above 0.
        best_genotype = copy.deepcopy(self.model.genotype())
        history = {'train_loss': [], 'val_loss': [], 'train_acc': [], 'val_acc': []}
        for epoch in range(self.epochs):
            train_loss, train_acc = self._train()
            val_loss, val_acc = self._validate()
            self.scheduler.step()
            history['train_loss'].append(train_loss)
            history['val_loss'].append(val_loss)
            history['train_acc'].append(train_acc)
            history['val_acc'].append(val_acc)
            if val_acc > best_acc:
                best_acc = val_acc
                best_genotype = copy.deepcopy(self.model.genotype())
            print(f"Epoch: {epoch + 1}/{self.epochs} | "
                  f"Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | "
                  f"Train Acc: {train_acc:.2f}% | Val Acc: {val_acc:.2f}%")
        return best_genotype, history
4. 完整训练流程
# Data preparation
def prepare_data(batch_size=64, val_ratio=0.1):
    """Build the CIFAR-10 training and validation loaders.

    The CIFAR-10 training set is downloaded to ``./data`` if needed and split
    randomly; ``val_ratio`` of it becomes the validation subset. Both subsets
    come from the same dataset object, so they share the same transform,
    including the random crop and flip.

    Args:
        batch_size: Batch size of both loaders.
        val_ratio: Fraction of the training set held out for validation.

    Returns:
        ``(train_loader, val_loader)``; only the training loader shuffles.
    """
    norm_mean = (0.4914, 0.4822, 0.4465)
    norm_std = (0.2023, 0.1994, 0.2010)
    transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(norm_mean, norm_std),
    ])

    full_dataset = datasets.CIFAR10(
        root='./data', train=True, download=True, transform=transform)

    num_samples = len(full_dataset)
    val_size = int(val_ratio * num_samples)
    split_lengths = [num_samples - val_size, val_size]
    train_dataset, val_dataset = torch.utils.data.random_split(full_dataset, split_lengths)

    train_loader = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
    val_loader = DataLoader(
        val_dataset, batch_size=batch_size, shuffle=False, num_workers=2)
    return train_loader, val_loader
# Entry point
def main():
    """Run the architecture search, then report and plot its results."""
    train_loader, val_loader = prepare_data()

    # Build the search network
    criterion = nn.CrossEntropyLoss().to(device)
    model = Network(C=16, num_classes=10, layers=8, criterion=criterion)

    # Run the search
    darts = DARTS(model, train_loader, val_loader, epochs=50)
    best_genotype, history = darts.search()

    # Report the result
    print("Best Genotype:", best_genotype)
    # NOTE(review): this renders the model's genotype as it stands after the
    # last epoch, which can differ from best_genotype when the best validation
    # accuracy was reached earlier.
    model.plot_genotype("best_architecture")

    # Plot the loss and accuracy curves side by side
    plt.figure(figsize=(12, 4))
    panels = (
        (1, 'train_loss', 'val_loss', 'Loss Curve'),
        (2, 'train_acc', 'val_acc', 'Accuracy Curve'),
    )
    for position, train_key, val_key, title in panels:
        plt.subplot(1, 2, position)
        plt.plot(history[train_key], label='Train')
        plt.plot(history[val_key], label='Validation')
        plt.title(title)
        plt.legend()
    plt.savefig('search_progress.png')
    plt.show()


if __name__ == "__main__":
    main()
输出为:
使用设备: cuda
Files already downloaded and verified
Epoch: 1/50 | Train Loss: 2.3124 | Val Loss: 2.1910 | Train Acc: 46.93% | Val Acc: 55.88%
Epoch: 2/50 | Train Loss: 1.3163 | Val Loss: 1.1209 | Train Acc: 67.39% | Val Acc: 68.60%
Epoch: 3/50 | Train Loss: 1.0027 | Val Loss: 0.8909 | Train Acc: 75.54% | Val Acc: 75.54%
Epoch: 4/50 | Train Loss: 0.8162 | Val Loss: 0.7608 | Train Acc: 80.52% | Val Acc: 79.12%
Epoch: 5/50 | Train Loss: 0.7098 | Val Loss: 0.7313 | Train Acc: 83.37% | Val Acc: 79.44%
Epoch: 6/50 | Train Loss: 0.6268 | Val Loss: 0.6517 | Train Acc: 85.73% | Val Acc: 82.10%
Epoch: 7/50 | Train Loss: 0.5662 | Val Loss: 0.6084 | Train Acc: 87.38% | Val Acc: 83.04%
Epoch: 8/50 | Train Loss: 0.5164 | Val Loss: 0.5669 | Train Acc: 88.96% | Val Acc: 84.56%
Epoch: 9/50 | Train Loss: 0.4790 | Val Loss: 0.5206 | Train Acc: 90.14% | Val Acc: 85.66%
Epoch: 10/50 | Train Loss: 0.4447 | Val Loss: 0.5097 | Train Acc: 91.23% | Val Acc: 85.64%
Epoch: 11/50 | Train Loss: 0.4135 | Val Loss: 0.5081 | Train Acc: 92.16% | Val Acc: 85.78%
Epoch: 12/50 | Train Loss: 0.3887 | Val Loss: 0.5135 | Train Acc: 92.81% | Val Acc: 85.76%
Epoch: 13/50 | Train Loss: 0.3687 | Val Loss: 0.4952 | Train Acc: 93.40% | Val Acc: 86.02%
Epoch: 14/50 | Train Loss: 0.3490 | Val Loss: 0.4915 | Train Acc: 94.02% | Val Acc: 86.72%
Epoch: 15/50 | Train Loss: 0.3323 | Val Loss: 0.5027 | Train Acc: 94.69% | Val Acc: 86.20%
Epoch: 16/50 | Train Loss: 0.3109 | Val Loss: 0.4722 | Train Acc: 95.34% | Val Acc: 87.44%
Epoch: 17/50 | Train Loss: 0.2952 | Val Loss: 0.4687 | Train Acc: 95.74% | Val Acc: 87.14%
Epoch: 18/50 | Train Loss: 0.2780 | Val Loss: 0.4605 | Train Acc: 96.38% | Val Acc: 87.92%
Epoch: 19/50 | Train Loss: 0.2591 | Val Loss: 0.4469 | Train Acc: 96.82% | Val Acc: 88.26%
Epoch: 20/50 | Train Loss: 0.2474 | Val Loss: 0.4479 | Train Acc: 97.22% | Val Acc: 88.04%
Epoch: 21/50 | Train Loss: 0.2371 | Val Loss: 0.4765 | Train Acc: 97.46% | Val Acc: 87.90%
Epoch: 22/50 | Train Loss: 0.2257 | Val Loss: 0.4213 | Train Acc: 97.78% | Val Acc: 89.00%
Epoch: 23/50 | Train Loss: 0.2100 | Val Loss: 0.4625 | Train Acc: 98.21% | Val Acc: 88.38%
Epoch: 24/50 | Train Loss: 0.2045 | Val Loss: 0.4474 | Train Acc: 98.24% | Val Acc: 88.74%
Epoch: 25/50 | Train Loss: 0.1859 | Val Loss: 0.4511 | Train Acc: 98.60% | Val Acc: 88.48%
Epoch: 26/50 | Train Loss: 0.1790 | Val Loss: 0.4307 | Train Acc: 98.81% | Val Acc: 89.54%
Epoch: 27/50 | Train Loss: 0.1644 | Val Loss: 0.4390 | Train Acc: 99.08% | Val Acc: 89.80%
Epoch: 28/50 | Train Loss: 0.1541 | Val Loss: 0.4344 | Train Acc: 99.17% | Val Acc: 89.60%
Epoch: 29/50 | Train Loss: 0.1449 | Val Loss: 0.4176 | Train Acc: 99.32% | Val Acc: 90.34%
Epoch: 30/50 | Train Loss: 0.1352 | Val Loss: 0.3915 | Train Acc: 99.47% | Val Acc: 90.64%
Epoch: 31/50 | Train Loss: 0.1261 | Val Loss: 0.4300 | Train Acc: 99.58% | Val Acc: 90.20%
Epoch: 32/50 | Train Loss: 0.1183 | Val Loss: 0.3936 | Train Acc: 99.67% | Val Acc: 91.10%
Epoch: 33/50 | Train Loss: 0.1056 | Val Loss: 0.3889 | Train Acc: 99.77% | Val Acc: 91.00%
Epoch: 34/50 | Train Loss: 0.0990 | Val Loss: 0.3937 | Train Acc: 99.81% | Val Acc: 91.00%
Epoch: 35/50 | Train Loss: 0.0949 | Val Loss: 0.3694 | Train Acc: 99.77% | Val Acc: 92.16%
Epoch: 36/50 | Train Loss: 0.0862 | Val Loss: 0.3788 | Train Acc: 99.89% | Val Acc: 91.72%
Epoch: 37/50 | Train Loss: 0.0815 | Val Loss: 0.3893 | Train Acc: 99.90% | Val Acc: 91.52%
Epoch: 38/50 | Train Loss: 0.0768 | Val Loss: 0.3847 | Train Acc: 99.92% | Val Acc: 91.92%
Epoch: 39/50 | Train Loss: 0.0729 | Val Loss: 0.3602 | Train Acc: 99.95% | Val Acc: 91.90%
Epoch: 40/50 | Train Loss: 0.0689 | Val Loss: 0.3846 | Train Acc: 99.94% | Val Acc: 91.68%
Epoch: 41/50 | Train Loss: 0.0656 | Val Loss: 0.3361 | Train Acc: 99.95% | Val Acc: 92.62%
Epoch: 42/50 | Train Loss: 0.0625 | Val Loss: 0.3563 | Train Acc: 99.96% | Val Acc: 92.18%
Epoch: 43/50 | Train Loss: 0.0598 | Val Loss: 0.3475 | Train Acc: 99.96% | Val Acc: 92.28%
Epoch: 44/50 | Train Loss: 0.0579 | Val Loss: 0.3468 | Train Acc: 99.94% | Val Acc: 92.22%
Epoch: 45/50 | Train Loss: 0.0561 | Val Loss: 0.3680 | Train Acc: 99.94% | Val Acc: 91.64%
Epoch: 46/50 | Train Loss: 0.0532 | Val Loss: 0.3334 | Train Acc: 99.95% | Val Acc: 92.40%
Epoch: 47/50 | Train Loss: 0.0509 | Val Loss: 0.3381 | Train Acc: 99.96% | Val Acc: 92.50%
Epoch: 48/50 | Train Loss: 0.0493 | Val Loss: 0.3517 | Train Acc: 99.95% | Val Acc: 92.16%
Epoch: 49/50 | Train Loss: 0.0474 | Val Loss: 0.3305 | Train Acc: 99.95% | Val Acc: 92.30%
Epoch: 50/50 | Train Loss: 0.0458 | Val Loss: 0.3305 | Train Acc: 99.94% | Val Acc: 92.84%
Best Genotype: [[('conv_5x5', 0), ('conv_5x5', 1)], [('none', 0), ('conv_5x5', 1), ('conv_5x5', 2)], [('conv_5x5', 0), ('conv_5x5', 1), ('conv_5x5', 2), ('conv_5x5', 3)], [('conv_5x5', 0), ('conv_5x5', 1), ('conv_5x5', 2), ('conv_5x5', 3), ('conv_5x5', 4)]]
perl: warning: Setting locale failed.
perl: warning: Please check that your locale settings:
LANGUAGE = (unset),
LC_ALL = (unset),
LC_CTYPE = "C.UTF-8",
LANG = "en_US.UTF-8"
are supported and installed on your system.
perl: warning: Falling back to the standard locale ("C").
Error: no "view" mailcap rules found for type "image/png"
/usr/bin/xdg-open: 869: www-browser: not found
/usr/bin/xdg-open: 869: links2: not found
/usr/bin/xdg-open: 869: elinks: not found
/usr/bin/xdg-open: 869: links: not found
/usr/bin/xdg-open: 869: lynx: not found
/usr/bin/xdg-open: 869: w3m: not found
xdg-open: no method available for opening 'best_architecture.png'
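上面的 Error: no "view" mailcap rules found for type "image/png" 以及随后的 xdg-open 报错,是因为 dot.render(filename, view=True) 会在渲染后尝试调用系统默认的图片查看程序,而当前环境(例如没有图形界面的服务器)缺少可用的查看工具(如浏览器)。
这些提示不影响结果:文件 best_architecture.png 仍会生成,只是无法自动弹出预览;把 view 参数改为 False 即可避免这些提示。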
三、进阶话题
1. 搜索空间设计技巧
class MacroSearchSpace:
    """Example of a macro-architecture search space.

    Each attribute lists the candidate values of one macro-level choice.
    """

    def __init__(self):
        # Candidate input resolutions
        self.resolutions = [224, 192, 160, 128]
        # Candidate network depths
        self.depths = [3, 4, 5, 6]
        # Candidate initial channel counts
        self.widths = [32, 64, 96, 128]
        # Candidate operation names: a live view of the keys of OPS
        self.ops = OPS.keys()
2. 多目标NAS实现
class MultiObjectiveNAS:
    """Score candidate architectures on accuracy and latency together.

    Args:
        model: Stored on the instance; not used by ``evaluate``.
        latency_predictor: Callable mapping a genotype to a latency value.
    """

    def __init__(self, model, latency_predictor):
        self.model = model
        self.latency_predictor = latency_predictor

    def evaluate(self, genotype):
        """Return accuracy, latency and a combined score for ``genotype``.

        NOTE(review): ``evaluate_accuracy`` is not defined in the visible
        code; it is presumably provided elsewhere and must be in scope before
        this method is called.
        """
        # Predicted latency
        latency = self.latency_predictor(genotype)
        # Measured accuracy
        accuracy = evaluate_accuracy(genotype)
        # The exponent -0.07 is the balancing factor: higher latency lowers the score.
        score = accuracy * (latency ** -0.07)
        return {
            'accuracy': accuracy,
            'latency': latency,
            'score': score,
        }
3. 实际应用建议
场景 | 推荐方法 | 理由 |
---|---|---|
移动端部署 | ProxylessNAS | 直接优化目标设备指标 |
研究探索 | DARTS | 灵活可扩展 |
工业级应用 | ENAS | 搜索效率高 |
四、总结与展望
本文实现了基于DARTS的神经架构搜索系统,主要亮点包括:
完整实现了可微分架构搜索:包括混合操作、细胞结构和双层优化
可视化搜索过程:支持架构基因型的图形化展示
实用训练技巧:采用余弦退火学习率等优化策略
在下一篇文章中,我们将探讨图生成模型与分子设计,介绍如何利用深度生成模型设计新型分子结构。