Preface
This post walks through a fairly recent paper in low-light image enhancement, Retinexformer, from both the theory and the code. It works well and set new state-of-the-art results on thirteen low-light enhancement benchmarks.
Paper: Retinexformer: One-stage Retinex-based Transformer for Low-light Image Enhancement
Paper link: https://arxiv.org/abs/2303.06705
Code: https://github.com/caiyuanhao1998/Retinexformer
Main Contributions
1. First to bring Retinex theory into a Transformer framework
Classical Retinex theory models an image as reflectance × illumination (Reflectance × Illumination);
Retinexformer turns this idea into a structured network with illumination estimation, denoising, and reconstruction modules;
and it is organized as a stage-wise architecture that extracts and enhances illumination information stage by stage.
2. An Illumination Estimator module plus the Transformer-based IGAB (Illumination-Guided Attention Block)
The Illumination Estimator extracts an illumination map from the low-light image;
IGAB uses attention to adaptively fuse local and global information, guiding denoising and detail reconstruction;
this improves the model's ability to handle complex illumination distributions, preserving structural detail even in extremely dark scenes.
3. A multi-stage design for progressive enhancement
Each stage contains illumination estimation, enhancement, and reconstruction;
the features and image produced by one stage are fed into the next for progressive refinement;
this keeps detail recovery consistent and robust.
Paper Reading
The distinguishing point of this paper is that it explicitly takes degradation factors into account: the noise and artifacts hidden in dark regions, plus the corruptions introduced or amplified by lighting the image up.
The original Retinex theory:
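I = R ⊙ L, where I is the observed image, R the reflectance, L the illumination, and ⊙ denotes element-wise multiplication.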
Taking degradation factors into account:
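In the paper this is written, roughly, by adding perturbation terms to both components: I = (R + R̂) ⊙ (L + L̂). Multiplying by a light-up map L̄ (chosen so that L ⊙ L̄ ≈ 1) then gives a lit-up image I_lu = I ⊙ L̄ = R + C, where C collects the noise and artifacts that the subsequent restorer has to remove.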
With this formulation in mind, let's look at the model and the loss function. Of course, you will need basic knowledge of Transformers.
Model
1.Illumination Estimator
Let's look at it together with the code:
import torch
import torch.nn as nn


class Illumination_Estimator(nn.Module):
    def __init__(self, n_fea_middle, n_fea_in=4, n_fea_out=3):
        # the __init__ arguments configure the module; forward() receives the external input
        super(Illumination_Estimator, self).__init__()

        self.conv1 = nn.Conv2d(n_fea_in, n_fea_middle, kernel_size=1, bias=True)

        self.depth_conv = nn.Conv2d(
            n_fea_middle, n_fea_middle, kernel_size=5, padding=2, bias=True, groups=n_fea_in)

        self.conv2 = nn.Conv2d(n_fea_middle, n_fea_out, kernel_size=1, bias=True)

    def forward(self, img):
        # img:      b, c=3, h, w
        # mean_c:   b, c=1, h, w
        # illu_fea: b, c,   h, w
        # illu_map: b, c=3, h, w
        mean_c = img.mean(dim=1).unsqueeze(1)
        input = torch.cat([img, mean_c], dim=1)
        x_1 = self.conv1(input)
        illu_fea = self.depth_conv(x_1)
        illu_map = self.conv2(illu_fea)
        return illu_fea, illu_map
This part is just a simple CNN used for illumination estimation. Its steps are:
Input img: (B, 3, H, W), an RGB image.
Compute the image's average brightness (the mean of R, G, B at each pixel), giving mean_c of shape (B, 1, H, W).
Concatenate them into a 4-channel input tensor: RGB + mean_c → (B, 4, H, W).
conv1 and the depthwise convolution extract the feature illu_fea; then illu_map = self.conv2(illu_fea) produces the illumination map.
Finally it returns illu_fea, illu_map:
illu_fea: the intermediate illumination feature map (later used to guide the denoiser's attention);
illu_map: the final RGB illumination map (used to light up the image, for visualization, or for other modules).
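As a quick sanity check of the shapes (a minimal sketch with dummy data; n_fea_middle=40 is just an example value, not something fixed by the paper):
import torch

estimator = Illumination_Estimator(n_fea_middle=40)
img = torch.randn(2, 3, 128, 128)      # dummy batch of RGB images
illu_fea, illu_map = estimator(img)
print(illu_fea.shape)                   # torch.Size([2, 40, 128, 128])
print(illu_map.shape)                   # torch.Size([2, 3, 128, 128])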
2.Embedding
self.embedding = nn.Conv2d(in_dim, self.dim, 3, 1, 1, bias=False)
This simply uses a 3×3 convolution to project the input x from in_dim channels to dim (31 by default) channels, keeping the spatial size unchanged.
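A minimal shape check of this projection (illustrative only):
import torch
import torch.nn as nn

embedding = nn.Conv2d(3, 31, 3, 1, 1, bias=False)   # in_dim=3 -> dim=31, spatial size unchanged
x = torch.randn(1, 3, 128, 128)
print(embedding(x).shape)                            # torch.Size([1, 31, 128, 128])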
3.Encoder
This is the most complex part. The main building blocks are IGAB plus down-sampling (in the encoder) and up-sampling (in the decoder).
# trunc_normal_ is imported from timm in the original repo file
class Denoiser(nn.Module):
    def __init__(self, in_dim=3, out_dim=3, dim=31, level=2, num_blocks=[2, 4, 4]):
        super(Denoiser, self).__init__()
        self.dim = dim
        self.level = level

        # Input projection
        self.embedding = nn.Conv2d(in_dim, self.dim, 3, 1, 1, bias=False)

        # Encoder
        self.encoder_layers = nn.ModuleList([])
        dim_level = dim
        for i in range(level):
            self.encoder_layers.append(nn.ModuleList([
                IGAB(
                    dim=dim_level, num_blocks=num_blocks[i], dim_head=dim, heads=dim_level // dim),
                nn.Conv2d(dim_level, dim_level * 2, 4, 2, 1, bias=False),
                nn.Conv2d(dim_level, dim_level * 2, 4, 2, 1, bias=False)
            ]))
            dim_level *= 2

        # Bottleneck
        self.bottleneck = IGAB(
            dim=dim_level, dim_head=dim, heads=dim_level // dim, num_blocks=num_blocks[-1])

        # Decoder
        self.decoder_layers = nn.ModuleList([])
        for i in range(level):
            self.decoder_layers.append(nn.ModuleList([
                nn.ConvTranspose2d(dim_level, dim_level // 2, stride=2,
                                   kernel_size=2, padding=0, output_padding=0),
                nn.Conv2d(dim_level, dim_level // 2, 1, 1, bias=False),
                IGAB(
                    dim=dim_level // 2, num_blocks=num_blocks[level - 1 - i], dim_head=dim,
                    heads=(dim_level // 2) // dim),
            ]))
            dim_level //= 2

        # Output projection
        self.mapping = nn.Conv2d(self.dim, out_dim, 3, 1, 1, bias=False)

        # activation function
        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x, illu_fea):
        """
        x:        [b,c,h,w]  x is a feature map (the lit-up input), not the raw image
        illu_fea: [b,c,h,w]
        return out: [b,c,h,w]
        """
        # Embedding
        fea = self.embedding(x)

        # Encoder
        fea_encoder = []
        illu_fea_list = []
        for (IGAB, FeaDownSample, IlluFeaDownsample) in self.encoder_layers:
            fea = IGAB(fea, illu_fea)  # bchw
            illu_fea_list.append(illu_fea)
            fea_encoder.append(fea)
            fea = FeaDownSample(fea)
            illu_fea = IlluFeaDownsample(illu_fea)

        # Bottleneck
        fea = self.bottleneck(fea, illu_fea)

        # Decoder
        for i, (FeaUpSample, Fution, LeWinBlcok) in enumerate(self.decoder_layers):
            fea = FeaUpSample(fea)
            fea = Fution(
                torch.cat([fea, fea_encoder[self.level - 1 - i]], dim=1))
            illu_fea = illu_fea_list[self.level - 1 - i]
            fea = LeWinBlcok(fea, illu_fea)

        # Mapping
        out = self.mapping(fea) + x

        return out
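To make the shapes concrete, here is an illustrative trace through the default configuration (dim=31, level=2) with a 256×256 input:
embedding: (B, 3, 256, 256) → (B, 31, 256, 256)
encoder level 0: IGAB keeps (B, 31, 256, 256); the two stride-2 convs downsample fea and illu_fea to (B, 62, 128, 128)
encoder level 1: IGAB on (B, 62, 128, 128); downsampling gives (B, 124, 64, 64)
bottleneck: IGAB on (B, 124, 64, 64)
decoder step 0: upsample to (B, 62, 128, 128), concatenate the level-1 skip → (B, 124, 128, 128), 1×1 conv back to (B, 62, 128, 128), then IGAB
decoder step 1: upsample to (B, 31, 256, 256), concatenate the level-0 skip → (B, 62, 256, 256), 1×1 conv back to (B, 31, 256, 256), then IGAB
mapping: (B, 31, 256, 256) → (B, 3, 256, 256), plus the residual + x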
Zooming in on the encoder construction: each level holds one IGAB plus two stride-2 4×4 convolutions, one downsampling the feature map and one downsampling the illumination feature, and the channel width doubles after every level:
self.encoder_layers = nn.ModuleList([])
dim_level = dim
for i in range(level):
    self.encoder_layers.append(nn.ModuleList([
        IGAB(
            dim=dim_level, num_blocks=num_blocks[i], dim_head=dim, heads=dim_level // dim),
        nn.Conv2d(dim_level, dim_level * 2, 4, 2, 1, bias=False),
        nn.Conv2d(dim_level, dim_level * 2, 4, 2, 1, bias=False)
    ]))
    dim_level *= 2
IGAB: Illumination-Guided Attention Block
# PreNorm and FeedForward are small helper modules defined elsewhere in the same repo file
class IGAB(nn.Module):
    def __init__(
            self,
            dim,
            dim_head=64,
            heads=8,
            num_blocks=2,
    ):
        super().__init__()
        self.blocks = nn.ModuleList([])
        for _ in range(num_blocks):
            self.blocks.append(nn.ModuleList([
                IG_MSA(dim=dim, dim_head=dim_head, heads=heads),
                PreNorm(dim, FeedForward(dim=dim))
            ]))

    def forward(self, x, illu_fea):
        """
        x: [b,c,h,w]
        illu_fea: [b,c,h,w]
        return out: [b,c,h,w]
        """
        x = x.permute(0, 2, 3, 1)
        for (attn, ff) in self.blocks:
            x = attn(x, illu_fea_trans=illu_fea.permute(0, 2, 3, 1)) + x
            x = ff(x) + x
        out = x.permute(0, 3, 1, 2)
        return out
x = attn(x, illu_fea_trans=illu_fea.permute(0, 2, 3, 1)) + x
x = ff(x) + x
These two lines are the core of each block: illumination-guided attention with a residual connection, followed by a feed-forward network, also with a residual connection.
IG_MSA: Illumination-Guided Multi-head Self-Attention
# needs: import torch.nn.functional as F; from einops import rearrange
# GELU here is a helper defined elsewhere in the repo file (equivalent to nn.GELU)
class IG_MSA(nn.Module):
    def __init__(
            self,
            dim,
            dim_head=64,
            heads=8,
    ):
        super().__init__()
        self.num_heads = heads
        self.dim_head = dim_head
        self.to_q = nn.Linear(dim, dim_head * heads, bias=False)
        self.to_k = nn.Linear(dim, dim_head * heads, bias=False)
        self.to_v = nn.Linear(dim, dim_head * heads, bias=False)
        self.rescale = nn.Parameter(torch.ones(heads, 1, 1))
        self.proj = nn.Linear(dim_head * heads, dim, bias=True)
        self.pos_emb = nn.Sequential(
            nn.Conv2d(dim, dim, 3, 1, 1, bias=False, groups=dim),
            GELU(),
            nn.Conv2d(dim, dim, 3, 1, 1, bias=False, groups=dim),
        )
        self.dim = dim

    def forward(self, x_in, illu_fea_trans):
        """
        x_in: [b,h,w,c]            # input feature
        illu_fea_trans: [b,h,w,c]  # illumination feature, already permuted to b,h,w,c by IGAB
        return out: [b,h,w,c]
        """
        b, h, w, c = x_in.shape
        x = x_in.reshape(b, h * w, c)
        q_inp = self.to_q(x)
        k_inp = self.to_k(x)
        v_inp = self.to_v(x)
        illu_attn = illu_fea_trans  # illu_fea: b,c,h,w -> b,h,w,c
        # split the channel dimension C of every token into num_heads heads of size dim_head
        q, k, v, illu_attn = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=self.num_heads),
                                 (q_inp, k_inp, v_inp, illu_attn.flatten(1, 2)))
        v = v * illu_attn   # modulate V with the illumination features
        # transpose the last two dims of each head so that attention is computed over channels:
        # q, k, v: b, heads, hw, d  ->  b, heads, d, hw
        q = q.transpose(-2, -1)
        k = k.transpose(-2, -1)
        v = v.transpose(-2, -1)
        q = F.normalize(q, dim=-1, p=2)
        k = F.normalize(k, dim=-1, p=2)
        attn = (k @ q.transpose(-2, -1))   # A = K^T * Q, shape b, heads, d, d
        attn = attn * self.rescale
        attn = attn.softmax(dim=-1)
        x = attn @ v  # b, heads, d, hw
        x = x.permute(0, 3, 1, 2)   # move hw forward, preparing the reshape to [b, hw, c]
        x = x.reshape(b, h * w, self.num_heads * self.dim_head)  # concatenate the heads -> [b, hw, c]
        out_c = self.proj(x).view(b, h, w, c)  # project back and reshape to [b, h, w, c], matching x_in
        out_p = self.pos_emb(v_inp.reshape(b, h, w, c).permute(
            0, 3, 1, 2)).permute(0, 2, 3, 1)   # depthwise-conv positional embedding
        out = out_c + out_p
        return out
Flow:
1. Project x into Q, K, V.
2. Embed the illumination feature into V (multiplicative modulation).
3. Compute normalized attention.
4. Aggregate with the attention weights to get the output.
5. Add a convolution-based positional encoding as a complement.
The point to note is V:
v = v * illu_attn
Here illu_attn is simply illu_fea (after the layout permutation), so the value matrix is modulated by the illumination features.
Attention computation:
attn = (k @ q^T) * rescale
attn = softmax(attn)
q and k are first L2-normalized;
then the attention weights are computed (over the channel dimension, so the attention matrix is d × d per head rather than HW × HW);
they are scaled by rescale, a learnable per-head factor;
and finally passed through softmax. A shape sketch follows below.
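A minimal shape sketch of this transposed (channel-wise) attention, with made-up sizes, just to show where the attention matrix lives:
import torch
import torch.nn.functional as F

b, heads, n, d = 2, 4, 64 * 64, 16          # n = H*W tokens, d = dim_head (illustrative values)
q = torch.randn(b, heads, n, d)
k = torch.randn(b, heads, n, d)
v = torch.randn(b, heads, n, d)

q, k, v = q.transpose(-2, -1), k.transpose(-2, -1), v.transpose(-2, -1)   # (b, heads, d, n)
q = F.normalize(q, dim=-1, p=2)
k = F.normalize(k, dim=-1, p=2)

attn = k @ q.transpose(-2, -1)              # (b, heads, d, d): size independent of H*W
out = attn.softmax(dim=-1) @ v              # (b, heads, d, n)
print(attn.shape, out.shape)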
The positional encoding differs from that of a standard Transformer:
out_p = self.pos_emb(v_inp.reshape(b, h, w, c).permute(0, 3, 1, 2))
It is built from two depthwise convolutions;
its output has the same shape as x and is added to the attention output, and it is fully learnable.
To summarize this part: the V matrix is guided by the illumination features, while Q and K drive the usual Transformer attention computation; the two play different roles.
4.Bottleneck
A single IGAB module (with num_blocks[-1] attention blocks) applied at the lowest resolution.
5.Decoder
Upsampling (transposed convolution) → 1×1 Conv to fuse the skip connection and halve the channels → IGAB decoding.
Loss Functions
(1) Absolute error (L1)
(2) MSE Loss (Mean Squared Error)
(3) PSNR Loss
(4) Charbonnier Loss (a smooth variant of L1)
A rough sketch of these losses is given below.
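The training code ships its own loss implementations; the snippet below is only an illustrative sketch of what these four losses look like, not code copied from the repo:
import torch
import torch.nn.functional as F

def charbonnier_loss(pred, target, eps=1e-3):
    # Charbonnier: a smooth, differentiable variant of L1
    return torch.sqrt((pred - target) ** 2 + eps ** 2).mean()

def psnr_loss(pred, target, max_val=1.0):
    # negative PSNR, so minimizing this loss maximizes PSNR
    mse = F.mse_loss(pred, target)
    return -10.0 * torch.log10(max_val ** 2 / (mse + 1e-12))

pred, target = torch.rand(1, 3, 64, 64), torch.rand(1, 3, 64, 64)
print(F.l1_loss(pred, target).item(),         # (1) absolute error
      F.mse_loss(pred, target).item(),        # (2) MSE
      psnr_loss(pred, target).item(),         # (3) PSNR loss
      charbonnier_loss(pred, target).item())  # (4) Charbonnier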
Workflow
As mentioned earlier, the enhancement is progressive. Taking stage = 3 as an example of how the stages chain together:
Input Image
│
▼
[Stage 1]
├─ Illumination Estimator → illu_map_1, illu_fea_1
├─ Enhance Image → input_1 = img * illu_map_1 + img
└─ Denoiser(input_1, illu_fea_1) → out_1
│
▼
[Stage 2]
├─ Illumination Estimator → illu_map_2, illu_fea_2
├─ Enhance Image → input_2 = out_1 * illu_map_2 + out_1
└─ Denoiser(input_2, illu_fea_2) → out_2
│
▼
[Stage 3]
├─ Illumination Estimator → illu_map_3, illu_fea_3
├─ Enhance Image → input_3 = out_2 * illu_map_3 + out_2
└─ Denoiser(input_3, illu_fea_3) → out_3
│
▼
Output Image (out_3)
In other words, the same (illumination estimator → light-up → denoiser) stage is repeated three times to achieve the enhancement; a minimal sketch of this chaining follows.
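A minimal sketch of the chaining, assuming the classes shown above (plus the repo's small helpers) are in scope; the wrapper class names and hyper-parameters here are illustrative, and the repo's own wrapper classes may differ in detail:
import torch
import torch.nn as nn

class SingleStage(nn.Module):
    # one stage = illumination estimation + light-up + denoising, as in the diagram above
    def __init__(self, dim=40, level=2, num_blocks=[1, 2, 2]):
        super().__init__()
        self.estimator = Illumination_Estimator(n_fea_middle=dim)
        self.denoiser = Denoiser(in_dim=3, out_dim=3, dim=dim, level=level, num_blocks=num_blocks)

    def forward(self, img):
        illu_fea, illu_map = self.estimator(img)
        input_img = img * illu_map + img       # the "Enhance Image" step of the diagram
        return self.denoiser(input_img, illu_fea)

class MultiStageModel(nn.Module):
    # stack `stage` copies of SingleStage; each stage's output feeds the next
    def __init__(self, stage=3):
        super().__init__()
        self.body = nn.Sequential(*[SingleStage() for _ in range(stage)])

    def forward(self, img):
        return self.body(img)

model = MultiStageModel(stage=3)
out = model(torch.randn(1, 3, 256, 256))
print(out.shape)   # torch.Size([1, 3, 256, 256])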
Reproduction
1.Environment configuration
Open the test script, find the following argument definitions, and fill in your own paths. Note that the contents of the YAML file referenced by --opt do not need to be changed; just replace the other paths with your own.
parser = argparse.ArgumentParser(description='Image Enhancement using Retinexformer')
parser.add_argument('--input_dir', default=r"E:\xjq\angguang\11111low_light_date\data\Eval\Eval\Huawei\low",
                    type=str, help='Directory of input images')
parser.add_argument('--result_dir', default=r"E:\xjq\NEW REA\Retinexformer-master\result\EVAL",
                    type=str, help='Directory for output results')
parser.add_argument('--opt', type=str, default=r'E:\xjq\NEW REA\Retinexformer-master\Options\RetinexFormer_NTIRE.yml',
                    help='Path to option YAML file.')
parser.add_argument('--weights', default=r'E:\xjq\NEW REA\Retinexformer-master\NTIRE.pth',
                    type=str, help='Path to weights')
parser.add_argument('--gpus', type=str, default="0", help='GPU devices.')
parser.add_argument('--self_ensemble', action='store_true', help='Use self-ensemble for better results')