ICCV低光照增强网络Retinexform论文阅读及代码复现

发布于:2025-07-05 ⋅ 阅读:(22) ⋅ 点赞:(0)

前言

        本文从原理和代码介绍低照度增强领域中比较新的一篇论文——Retinexformer,其效果不错,刷新了十三大暗光增强效果榜单。

论文名称:Retinexformer: One-stage Retinex-based Transformer for Low-light Image Enhancement

论文地址[2303.06705] Retinexformer: One-stage Retinex-based Transformer for Low-light Image Enhancement[2303.06705] Retinexformer: One-stage Retinex-based Transformer for Low-light Image Enhancement

代码地址https://github.com/caiyuanhao1998/Retinexformer

主要贡献

1. 首次将 Retinex 理论引入 Transformer 框架
  • 传统的 Retinex 理论将图像建模为 反射分量 × 光照分量(Reflectance × Illumination)

  • RetinexFormer 将这一思想模块化为结构化网络,包括光照估计、去噪和重建模块;

  • 构建了Stage-wise 分层架构,分阶段提取和增强光照信息。

2. 提出 Illumination Estimation 模块 + Transformer-based IGAB(Information-Guided Aggregation Block)
  • Illumination Estimator 提取低光图像的光照图;

  • IGAB 利用注意力机制自适应融合局部与全局信息,引导去噪与细节重建;

  • 提高了模型对复杂光照分布的建模能力,尤其在极暗场景下仍能保持结构细节。

3. 提出 Multi-Stage Design 以逐步增强图像
  • 利用多阶段策略:每一阶段都包含光照估计、增强和重建过程;

  • 上一阶段的特征与图像结果被输入下一阶段以递进式增强;

  • 保证细节恢复的连续性和鲁棒性。

论文阅读

本文的特色就是将这些退化因素考虑在内。

原始Retinex理论:

考虑退化因素:

根据以上理论我们来看看模型和损失函数,当然,你需要知道Transformer最基础的知识。

模型

1.illumination estimator

结合代码来看

class Illumination_Estimator(nn.Module):
    def __init__(
            self, n_fea_middle, n_fea_in=4, n_fea_out=3):  #__init__部分是内部属性,而forward的输入才是外部输入
        super(Illumination_Estimator, self).__init__()

        self.conv1 = nn.Conv2d(n_fea_in, n_fea_middle, kernel_size=1, bias=True)

        self.depth_conv = nn.Conv2d(
            n_fea_middle, n_fea_middle, kernel_size=5, padding=2, bias=True, groups=n_fea_in)

        self.conv2 = nn.Conv2d(n_fea_middle, n_fea_out, kernel_size=1, bias=True)

    def forward(self, img):
        # img:        b,c=3,h,w
        # mean_c:     b,c=1,h,w
        
        # illu_fea:   b,c,h,w
        # illu_map:   b,c=3,h,w
        
        mean_c = img.mean(dim=1).unsqueeze(1)
        # stx()
        input = torch.cat([img,mean_c], dim=1)

        x_1 = self.conv1(input)
        illu_fea = self.depth_conv(x_1)
        illu_map = self.conv2(illu_fea)
        return illu_fea, illu_map

这一个部分就是一个简单的CNN结构用来光照估计,它的步骤如下:

输入:img(B, 3, H, W),RGB图像、

计算图像的平均亮度(每个像素位置RGB的均值),得到形状 (B, 1, H, W)

合并成一个 4 通道的输入张量:RGB + mean_c → (B, 4, H, W)

通过深度卷积提取特征illu_fea 再illu_map = self.conv2(illu_fea)生成光照图

最后输出return illu_fea, illu_map

 illu_fea:中间的光照特征图

 illu_map:最终输出的 RGB 光照图(可视化、增强或用于其他模块)

2.Embedding

self.embedding = nn.Conv2d(in_dim, self.dim, 3, 1, 1, bias=False)

主要是通过Conv将X从in_dim投影到31维

3.Encoder

这个是最为复杂主要模块就是IGAB还有上下采样

class Denoiser(nn.Module):
    def __init__(self, in_dim=3, out_dim=3, dim=31, level=2, num_blocks=[2, 4, 4]):
        super(Denoiser, self).__init__()
        self.dim = dim
        self.level = level

        # Input projection
        self.embedding = nn.Conv2d(in_dim, self.dim, 3, 1, 1, bias=False)

        # Encoder
        self.encoder_layers = nn.ModuleList([])
        dim_level = dim
        for i in range(level):
            self.encoder_layers.append(nn.ModuleList([
                IGAB(
                    dim=dim_level, num_blocks=num_blocks[i], dim_head=dim, heads=dim_level // dim),
                nn.Conv2d(dim_level, dim_level * 2, 4, 2, 1, bias=False),
                nn.Conv2d(dim_level, dim_level * 2, 4, 2, 1, bias=False)
            ]))
            dim_level *= 2

        # Bottleneck
        self.bottleneck = IGAB(
            dim=dim_level, dim_head=dim, heads=dim_level // dim, num_blocks=num_blocks[-1])

        # Decoder
        self.decoder_layers = nn.ModuleList([])
        for i in range(level):
            self.decoder_layers.append(nn.ModuleList([
                nn.ConvTranspose2d(dim_level, dim_level // 2, stride=2,
                                   kernel_size=2, padding=0, output_padding=0),
                nn.Conv2d(dim_level, dim_level // 2, 1, 1, bias=False),
                IGAB(
                    dim=dim_level // 2, num_blocks=num_blocks[level - 1 - i], dim_head=dim,
                    heads=(dim_level // 2) // dim),
            ]))
            dim_level //= 2

        # Output projection
        self.mapping = nn.Conv2d(self.dim, out_dim, 3, 1, 1, bias=False)

        # activation function
        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x, illu_fea):
        """
        x:          [b,c,h,w]         x是feature, 不是image
        illu_fea:   [b,c,h,w]
        return out: [b,c,h,w]
        """

        # Embedding
        fea = self.embedding(x)

        # Encoder
        fea_encoder = []
        illu_fea_list = []
        for (IGAB, FeaDownSample, IlluFeaDownsample) in self.encoder_layers:
            fea = IGAB(fea,illu_fea)  # bchw
            illu_fea_list.append(illu_fea)
            fea_encoder.append(fea)
            fea = FeaDownSample(fea)
            illu_fea = IlluFeaDownsample(illu_fea)

        # Bottleneck
        fea = self.bottleneck(fea,illu_fea)

        # Decoder
        for i, (FeaUpSample, Fution, LeWinBlcok) in enumerate(self.decoder_layers):
            fea = FeaUpSample(fea)
            fea = Fution(
                torch.cat([fea, fea_encoder[self.level - 1 - i]], dim=1))
            illu_fea = illu_fea_list[self.level-1-i]
            fea = LeWinBlcok(fea,illu_fea)

        # Mapping
        out = self.mapping(fea) + x

        return out
self.encoder_layers = nn.ModuleList([])
        dim_level = dim
        for i in range(level):
            self.encoder_layers.append(nn.ModuleList([
                IGAB(
                    dim=dim_level, num_blocks=num_blocks[i], dim_head=dim, heads=dim_level // dim),
                nn.Conv2d(dim_level, dim_level * 2, 4, 2, 1, bias=False),
                nn.Conv2d(dim_level, dim_level * 2, 4, 2, 1, bias=False)
            ]))
            dim_level *= 2

IGBA

class IGAB(nn.Module):
    def __init__(
            self,
            dim,
            dim_head=64,
            heads=8,
            num_blocks=2,
    ):
        super().__init__()
        self.blocks = nn.ModuleList([])
        for _ in range(num_blocks):
            self.blocks.append(nn.ModuleList([
                IG_MSA(dim=dim, dim_head=dim_head, heads=heads),
                PreNorm(dim, FeedForward(dim=dim))
            ]))

    def forward(self, x, illu_fea):
        """
        x: [b,c,h,w]
        illu_fea: [b,c,h,w]
        return out: [b,c,h,w]
        """
        x = x.permute(0, 2, 3, 1)
        for (attn, ff) in self.blocks:
            x = attn(x, illu_fea_trans=illu_fea.permute(0, 2, 3, 1)) + x
            x = ff(x) + x
        out = x.permute(0, 3, 1, 2)
        return out

         x = attn(x, illu_fea_trans=illu_fea.permute(0, 2, 3, 1)) + x
                    x = ff(x) + x(这里主要是一个残差网络加一个前馈网络)

IG_MSA: Illumination-Guided Multi-head Self Attention

class IG_MSA(nn.Module):
    def __init__(
            self,
            dim,
            dim_head=64,
            heads=8,
    ):
        super().__init__()
        self.num_heads = heads
        self.dim_head = dim_head
        self.to_q = nn.Linear(dim, dim_head * heads, bias=False)
        self.to_k = nn.Linear(dim, dim_head * heads, bias=False)
        self.to_v = nn.Linear(dim, dim_head * heads, bias=False)
        self.rescale = nn.Parameter(torch.ones(heads, 1, 1))
        self.proj = nn.Linear(dim_head * heads, dim, bias=True)
        self.pos_emb = nn.Sequential(
            nn.Conv2d(dim, dim, 3, 1, 1, bias=False, groups=dim),
            GELU(),
            nn.Conv2d(dim, dim, 3, 1, 1, bias=False, groups=dim),
        )
        self.dim = dim

    def forward(self, x_in, illu_fea_trans):
        """
        x_in: [b,h,w,c]         # input_feature
        illu_fea: [b,h,w,c]         # mask shift? 为什么是 b, h, w, c?
        return out: [b,h,w,c]
        """
        b, h, w, c = x_in.shape
        x = x_in.reshape(b, h * w, c)
        q_inp = self.to_q(x)
        k_inp = self.to_k(x)
        v_inp = self.to_v(x)
        illu_attn = illu_fea_trans # illu_fea: b,c,h,w -> b,h,w,c
        q, k, v, illu_attn = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=self.num_heads),
                                 (q_inp, k_inp, v_inp, illu_attn.flatten(1, 2)))#空间转化,这相当于把每个 token 的通道维 C 拆成多个注意力头(h个),每个头 dim_head 大小。
        v = v * illu_attn
        # q: b,heads,hw,c#的作用是 转置每个注意力头的特征矩阵的最后两个维度,为后续的矩阵乘法(attention计算)准备合适的维度。
        q = q.transpose(-2, -1)
        k = k.transpose(-2, -1)
        v = v.transpose(-2, -1)
        q = F.normalize(q, dim=-1, p=2)
        k = F.normalize(k, dim=-1, p=2)
        attn = (k @ q.transpose(-2, -1))   # A = K^T*Q
        attn = attn * self.rescale
        attn = attn.softmax(dim=-1)
        x = attn @ v   # b,heads,d,hw
        x = x.permute(0, 3, 1, 2)    # Transpose把 HW 移到 batch 维度后面,为后续 reshape 成 [B, HW, C] 做准备。
        x = x.reshape(b, h * w, self.num_heads * self.dim_head)#这是将多头输出拼接在一起,得到标准输出的维度 [B, HW, C],注意这是还未还原成空间格式的输出。
        out_c = self.proj(x).view(b, h, w, c)#最后 reshape 回 [B, H, W, C],与输入 x_in 保持一致;
        out_p = self.pos_emb(v_inp.reshape(b, h, w, c).permute(
            0, 3, 1, 2)).permute(0, 2, 3, 1)#位置编码器
        out = out_c + out_p

        return out

流程:
    1. 将 x 映射成 Q、K、V
    2. 将光照特征嵌入进 V(乘法调制)
    3. 执行归一化注意力计算
    4. 加权求和得到输出
    5. 加上位置编码(Conv-based)作为补充
需要注意的就是V:

v = v * illu_attn

这里的illu_attn就是illu_fea

Attention 计算:

attn = (k @ q^T) * rescale
attn = softmax(attn)

先归一化 q, k(L2);

然后计算注意力权重;

再乘以 rescale(每个 head 可学习的缩放因子);

再用 softmax

位置编码不同于一般的transformer

out_p = self.pos_emb(v_inp.reshape(b, h, w, c).permute(0, 3, 1, 2))

用两个深度可分离卷积(depthwise conv)构建位置嵌入;

输出形状与 x 相同,加法融合到输出中,他是可学习的参数。

这一段V阵受光照的引导,Q,K阵则为transformer的注意力机制,两者不一样

4.Bottleneck

IGAB模块重复一次

5.Decoder

上采样→Conv(压缩通道)→IGAB解码

损失函数

(1)绝对误差

(2)MSE Loss(Mean Squared Error)

(3)PSNR Loss

(4) Charbonnier Loss(平滑 L1)

工作流程

前面说它是逐步增强,就以stage=3来说明如何逐步:

Input Image
   │
   ▼
[Stage 1]
   ├─ Illumination Estimator        → illu_map_1, illu_fea_1
   ├─ Enhance Image                 → input_1 = img * illu_map_1 + img
   └─ Denoiser(input_1, illu_fea_1) → out_1
   │
   ▼
[Stage 2]
   ├─ Illumination Estimator        → illu_map_2, illu_fea_2
   ├─ Enhance Image                 → input_2 = out_1 * illu_map_2 + out_1
   └─ Denoiser(input_2, illu_fea_2) → out_2
   │
   ▼
[Stage 3]
   ├─ Illumination Estimator        → illu_map_3, illu_fea_3
   ├─ Enhance Image                 → input_3 = out_2 * illu_map_3 + out_2
   └─ Denoiser(input_3, illu_fea_3) → out_3
   │
   ▼
Output Image (out_3)

就是三个模块重复三次达到增强目的

复现

1.环境配置

打开这段代码找到如下代码依次填入路径就可以了,注意opt中的内容不需要更改,其它换成自己的即可

parser = argparse.ArgumentParser(description='Image Enhancement using Retinexformer')
parser.add_argument('--input_dir', default=r"E:\xjq\angguang\11111low_light_date\data\Eval\Eval\Huawei\low",
                    type=str, help='Directory of input images')
parser.add_argument('--result_dir', default=r"E:\xjq\NEW REA\Retinexformer-master\result\EVAL",
                    type=str, help='Directory for output results')
parser.add_argument('--opt', type=str, default=r'E:\xjq\NEW REA\Retinexformer-master\Options\RetinexFormer_NTIRE.yml',
                    help='Path to option YAML file.')
parser.add_argument('--weights', default=r'E:\xjq\NEW REA\Retinexformer-master\NTIRE.pth',
                    type=str, help='Path to weights')
parser.add_argument('--gpus', type=str, default="0", help='GPU devices.')
parser.add_argument('--self_ensemble', action='store_true', help='Use self-ensemble for better results')

结果