利用自适应双向对比重建网络与精细通道注意机制实现图像去雾化技术的PyTorch代码解析
漫谈图像去雾化的挑战
在计算机视觉领域,图像复原一直是研究热点。其中,图像去雾化技术尤其具有实际应用价值。然而,复杂的气象条件和多种因素干扰使得这一任务充满挑战。
传统的去雾方法往往难以在保留细节的同时移除雾气。深度学习的兴起为该问题带来了新的解决思路,但现有模型在处理不同光照条件、不同层次雾霾时仍显不足。
论文概述
这篇论文提出了一个创新性的网络结构:Unsupervised Bidirectional Contrastive Reconstruction Network (UB-CRN),结合了自适应Fine-Grained Channel Attention(FCA)机制。该方法在无监督学习框架下,通过双向对比重建和通道注意力调整,有效提升了去雾效果。
核心模块解析
Mix混合模块
代码中定义了一个
Mix
类,用于融合特征通道的信息:class Mix(nn.Module): def __init__(self, m=-0.80): super(Mix, self).__init__() w = torch.nn.Parameter(torch.FloatTensor([m]), requires_grad=True) w = torch.nn.Parameter(w, requires_grad=True) self.w = w self.mix_block = nn.Sigmoid() def forward(self, fea1, fea2): mix_factor = self.mix_block(self.w) out = fea1 * mix_factor.expand_as(fea1) + fea2 * (1 - mix_factor.expand_as(fea2)) return out
这个模块的作用是将两个特征图进行加权融合,使得网络能够自适应地决定每个通道的信息贡献程度。
自适应精细通道注意(FCA)
class FCAttention(nn.Module): def __init__(self, channel, b=1, gamma=2): super(FCAttention, self).__init__() self.avg_pool = nn.AdaptiveAvgPool2d(1) t = int(abs((math.log(channel, 2) + b) / gamma)) k = t if t % 2 else t + 1 self.conv1 = nn.Conv1d(1, 1, kernel_size=k, padding=int(k / 2), bias=False) self.fc = nn.Conv2d(channel, channel, 1, padding=0, bias=True) self.sigmoid = nn.Sigmoid() self.mix = Mix() def forward(self, x): b, c, _, _ = x.size() x_avg = self.avg_pool(x).view(b, c) x1 = self.conv1(x_avg.unsqueeze(2)).squeeze(-2) x2 = self.fc(x_avg.unsqueeze(-1).unsqueeze(-1)) out1 = torch.sum(torch.matmul(x1, x2), dim=1, keepdim=True) out1 = self.sigmoid(out1) out2 = torch.sum(torch.matmul(x2.transpose(-2, -1), x1.transpose(-2, -1)), dim=-2, keepdim=True) out2 = self.sigmoid(out2) out = self.mix(out1.permute(0, 2, 1).contiguous(), out2.permute(0, 2, 1).contiguous()) out = self.conv1(out.squeeze(-2).unsqueeze(-3)).permute(0, 1, -1) return x * out
这个模块通过自适应地计算通道间的依赖关系,赋予每个通道不同的权重,从而提升网络对特征的捕捉能力。
完整模型实现
以下是整个图像去雾化网络的代码框架:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import math
# 实现Mix模块
class Mix(nn.Module):
def __init__(self, m=-0.80):
super(Mix, self).__init__()
w = torch.nn.Parameter(torch.FloatTensor([m]), requires_grad=True)
self.w = nn.Parameter(w)
self.mix_block = nn.Sigmoid()
def forward(self, fea1, fea2):
mix_factor = self.mix_block(self.w)
out = fea1 * mix_factor.expand_as(fea1) + fea2 * (1 - mix_factor.expand_as(fea2))
return out
# 实现自适应Fine-Grained Channel Attention
class FCAttention(nn.Module):
def __init__(self, channel, b=1, gamma=2):
super(FCAttention, self).__init__()
self.avg_pool = nn.AdaptiveAvgPool2d(1)
t = int(abs((math.log(channel, 2) + b) / gamma))
k = t if t % 2 else t + 1
self.conv1 = nn.Conv1d(1, 1, kernel_size=k, padding=int(k/2), bias=False)
self.fc_layer = nn.Conv2d(channel, channel, kernel_size=1, padding=0, bias=True)
self.sigmoid = nn.Sigmoid()
self.mix = Mix()
def forward(self, x):
b, c, h, w = x.size()
features = x.view(b * c, -1)
# 特征聚合
avg_features = self.avg_pool(x).view(b, c)
avg_features.unsqueeze_(-1) # (b, c, 1)
# 使用一维卷积处理每个特征通道
x1 = self.conv1(avg_features.transpose(1,2)) # (b, 1, c)
x1 = x1.squeeze(dim=-2).transpose(0,1) # 调整维度,变为(c, b)
# 全连接层处理
x2 = self.fc_layer(avg_features.unsqueeze(-1).unsqueeze(-1))
x2 = x2.view(b, c) # (b, c)
# 计算通道间的注意力权重
out1 = torch.matmul(x1, x2.permute(0,1).contiguous())
out1 = self.sigmoid(out1.unsqueeze(-1).unsqueeze(-1))
out2 = torch.matmul(x2.permute(0,1), x1.contiguous())
out2 = self.sigmoid(out2.unsqueeze(-1).unsqueeze(-1))
# Mix模块融合
out3 = self.mix(out1.permute(0, 2, 1), out2.permute(0, 2, 1))
# 将混合特征调整形状并应用sigmoid函数
attention = F.conv1d(out3.unsqueeze(-1), self.conv1.weight.data[:, 0:1, :])
attention = attention.squeeze().transpose(1, 0).contiguous()
return x * (attention.view(b, c, 1, 1) )
# 初始化网络
def init_weights(m):
if isinstance(m, nn.Conv2d):
m.weight.data.normal_(0.0, 0.01)
if m.bias is not None:
m.bias.data.fill_(0.0)
# 定义整个图像去雾化网络
class Dehazer(nn.Module):
def __init__(self, feature_size=3):
super(Dehazer, self).__init__()
# 初始特征提取模块,这里简要处理,实际可根据需要增加其他结构
self.conv_init = nn.Conv2d(feature_size, 64, kernel_size=3, padding=1)
# 多尺度特征融合
self.attention = FCAttention(64, b=0.8, gamma=2)
# 后处理模块,可增加其他恢复网络结构如DRC等
self.conv_final = nn.Conv2d(64, 3, kernel_size=1, padding=0)
def forward(self, x):
x1 = F.relu(self.conv_init(x))
out = self.attention(x1)
out = (torch.tanh(self.conv_final(out))) + 1
return out * x
# 加载模型并初始化参数
model = Dehazer(3).cuda()
model.apply(init_weights)
# 定义优化器和损失函数
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
criterion = nn.MSELoss()
# 训练数据准备和训练循环
# 这里的Dataset和DataLoader需要根据实际数据集定制
...
# 每个epoch进行迭代
for epoch in range(20):
for batch, (inputs, targets) in enumerate(train_loader):
inputs = inputs.cuda()
targets = targets.cuda()
outputs = model(inputs)
loss = criterion(outputs, targets)
optimizer.zero_grad()
loss.backward()
optimizer.step()
# 模型保存
torch.save(model.state_dict(), 'dehazer.pth')
总结
通过上述代码,我们实现了一个基于自适应Fine-Grained通道注意力机制的图像去雾化模型。该模型能够自动学习每个通道的重要性,并相应地调整其权重,从而在不同光照条件下有效恢复清晰的图像。
关键步骤解释:
- 特征提取:使用卷积层从输入图像中提取初始特征。
- 通道注意力计算:通过自适应平均池化和一维卷积等操作,计算通道间的依赖关系,并得到每个通道的权重。
- 特征融合:利用Mix模块将原始特征与计算得到的注意力权重进行融合,生成最终的增强特征图。
- 目标图像恢复:基于融合后的特征图,通过反卷积或调整因子的方式得到去雾化后的输出。
这种方法克服了传统图像去雾化方法中通道间相互干扰的问题,能够更精准地恢复被雾霾遮蔽的细节信息。