audioLDM模型代码阅读(二)——HiFi-GAN模型代码分析

发布于:2025-09-01 ⋅ 阅读:(20) ⋅ 点赞:(0)

先给出完整的代码:

import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
from utils import init_weights, get_padding

LRELU_SLOPE = 0.1


class ResBlock1(torch.nn.Module):
    def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
        super(ResBlock1, self).__init__()
        self.h = h
        self.convs1 = nn.ModuleList([
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
                               padding=get_padding(kernel_size, dilation[0]))),
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
                               padding=get_padding(kernel_size, dilation[1]))),
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2],
                               padding=get_padding(kernel_size, dilation[2])))
        ])
        self.convs1.apply(init_weights)

        self.convs2 = nn.ModuleList([
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
                               padding=get_padding(kernel_size, 1))),
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
                               padding=get_padding(kernel_size, 1))),
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
                               padding=get_padding(kernel_size, 1)))
        ])
        self.convs2.apply(init_weights)

    def forward(self, x):
        for c1, c2 in zip(self.convs1, self.convs2):
            xt = F.leaky_relu(x, LRELU_SLOPE)
            xt = c1(xt)
            xt = F.leaky_relu(xt, LRELU_SLOPE)
            xt = c2(xt)
            x = xt + x
        return x

    def remove_weight_norm(self):
        for l in self.convs1:
            remove_weight_norm(l)
        for l in self.convs2:
            remove_weight_norm(l)


class ResBlock2(torch.nn.Module):
    def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)):
        super(ResBlock2, self).__init__()
        self.h = h
        self.convs = nn.ModuleList([
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
                               padding=get_padding(kernel_size, dilation[0]))),
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
                               padding=get_padding(kernel_size, dilation[1])))
        ])
        self.convs.apply(init_weights)

    def forward(self, x):
        for c in self.convs:
            xt = F.leaky_relu(x, LRELU_SLOPE)
            xt = c(xt)
            x = xt + x
        return x

    def remove_weight_norm(self):
        for l in self.convs:
            remove_weight_norm(l)


class Generator(torch.nn.Module):
    def __init__(self, h):
        super(Generator, self).__init__()
        self.h = h
        self.num_kernels = len(h.resblock_kernel_sizes)
        self.num_upsamples = len(h.upsample_rates)
        self.conv_pre = weight_norm(Conv1d(80, h.upsample_initial_channel, 7, 1, padding=3))
        resblock = ResBlock1 if h.resblock == '1' else ResBlock2

        self.ups = nn.ModuleList()
        for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
            self.ups.append(weight_norm(
                ConvTranspose1d(h.upsample_initial_channel//(2**i), h.upsample_initial_channel//(2**(i+1)),
                                k, u, padding=(k-u)//2)))

        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            ch = h.upsample_initial_channel//(2**(i+1))
            for j, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)):
                self.resblocks.append(resblock(h, ch, k, d))

        self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
        self.ups.apply(init_weights)
        self.conv_post.apply(init_weights)

    def forward(self, x):
        x = self.conv_pre(x)
        for i in range(self.num_upsamples):
            x = F.leaky_relu(x, LRELU_SLOPE)
            x = self.ups[i](x)
            xs = None
            for j in range(self.num_kernels):
                if xs is None:
                    xs = self.resblocks[i*self.num_kernels+j](x)
                else:
                    xs += self.resblocks[i*self.num_kernels+j](x)
            x = xs / self.num_kernels
        x = F.leaky_relu(x)
        x = self.conv_post(x)
        x = torch.tanh(x)

        return x

    def remove_weight_norm(self):
        print('Removing weight norm...')
        for l in self.ups:
            remove_weight_norm(l)
        for l in self.resblocks:
            l.remove_weight_norm()
        remove_weight_norm(self.conv_pre)
        remove_weight_norm(self.conv_post)


class DiscriminatorP(torch.nn.Module):
    def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
        super(DiscriminatorP, self).__init__()
        self.period = period
        norm_f = weight_norm if use_spectral_norm == False else spectral_norm
        self.convs = nn.ModuleList([
            norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
            norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
            norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
            norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
            norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))),
        ])
        self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))

    def forward(self, x):
        fmap = []

        # 1d to 2d
        b, c, t = x.shape
        if t % self.period != 0: # pad first
            n_pad = self.period - (t % self.period)
            x = F.pad(x, (0, n_pad), "reflect")
            t = t + n_pad
        x = x.view(b, c, t // self.period, self.period)

        for l in self.convs:
            x = l(x)
            x = F.leaky_relu(x, LRELU_SLOPE)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)
        x = torch.flatten(x, 1, -1)

        return x, fmap


class MultiPeriodDiscriminator(torch.nn.Module):
    def __init__(self):
        super(MultiPeriodDiscriminator, self).__init__()
        self.discriminators = nn.ModuleList([
            DiscriminatorP(2),
            DiscriminatorP(3),
            DiscriminatorP(5),
            DiscriminatorP(7),
            DiscriminatorP(11),
        ])

    def forward(self, y, y_hat):
        y_d_rs = []
        y_d_gs = []
        fmap_rs = []
        fmap_gs = []
        for i, d in enumerate(self.discriminators):
            y_d_r, fmap_r = d(y)
            y_d_g, fmap_g = d(y_hat)
            y_d_rs.append(y_d_r)
            fmap_rs.append(fmap_r)
            y_d_gs.append(y_d_g)
            fmap_gs.append(fmap_g)

        return y_d_rs, y_d_gs, fmap_rs, fmap_gs


class DiscriminatorS(torch.nn.Module):
    def __init__(self, use_spectral_norm=False):
        super(DiscriminatorS, self).__init__()
        norm_f = weight_norm if use_spectral_norm == False else spectral_norm
        self.convs = nn.ModuleList([
            norm_f(Conv1d(1, 128, 15, 1, padding=7)),
            norm_f(Conv1d(128, 128, 41, 2, groups=4, padding=20)),
            norm_f(Conv1d(128, 256, 41, 2, groups=16, padding=20)),
            norm_f(Conv1d(256, 512, 41, 4, groups=16, padding=20)),
            norm_f(Conv1d(512, 1024, 41, 4, groups=16, padding=20)),
            norm_f(Conv1d(1024, 1024, 41, 1, groups=16, padding=20)),
            norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
        ])
        self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))

    def forward(self, x):
        fmap = []
        for l in self.convs:
            x = l(x)
            x = F.leaky_relu(x, LRELU_SLOPE)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)
        x = torch.flatten(x, 1, -1)

        return x, fmap


class MultiScaleDiscriminator(torch.nn.Module):
    def __init__(self):
        super(MultiScaleDiscriminator, self).__init__()
        self.discriminators = nn.ModuleList([
            DiscriminatorS(use_spectral_norm=True),
            DiscriminatorS(),
            DiscriminatorS(),
        ])
        self.meanpools = nn.ModuleList([
            AvgPool1d(4, 2, padding=2),
            AvgPool1d(4, 2, padding=2)
        ])

    def forward(self, y, y_hat):
        y_d_rs = []
        y_d_gs = []
        fmap_rs = []
        fmap_gs = []
        for i, d in enumerate(self.discriminators):
            if i != 0:
                y = self.meanpools[i-1](y)
                y_hat = self.meanpools[i-1](y_hat)
            y_d_r, fmap_r = d(y)
            y_d_g, fmap_g = d(y_hat)
            y_d_rs.append(y_d_r)
            fmap_rs.append(fmap_r)
            y_d_gs.append(y_d_g)
            fmap_gs.append(fmap_g)

        return y_d_rs, y_d_gs, fmap_rs, fmap_gs


def feature_loss(fmap_r, fmap_g):
    loss = 0
    for dr, dg in zip(fmap_r, fmap_g):
        for rl, gl in zip(dr, dg):
            loss += torch.mean(torch.abs(rl - gl))

    return loss*2


def discriminator_loss(disc_real_outputs, disc_generated_outputs):
    loss = 0
    r_losses = []
    g_losses = []
    for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
        r_loss = torch.mean((1-dr)**2)
        g_loss = torch.mean(dg**2)
        loss += (r_loss + g_loss)
        r_losses.append(r_loss.item())
        g_losses.append(g_loss.item())

    return loss, r_losses, g_losses


def generator_loss(disc_outputs):
    loss = 0
    gen_losses = []
    for dg in disc_outputs:
        l = torch.mean((1-dg)**2)
        gen_losses.append(l)
        loss += l

    return loss, gen_losses

这段代码实现了一个基于生成对抗网络(GAN)的音频生成模型,具体结构类似于GHiFi-GAN(一种高性能声码器),主要用于从梅尔频谱(Mel-spectrogram)生成原始音频波形。下面分模块详细解释:

核心组件概览

代码包含生成器(Generator)、两种鉴别器(多周期鉴别器、多尺度鉴别器)、残差块(ResBlock)及对应的损失函数,形成完整的GAN训练框架。

1. 残差块(ResBlock)

在这段音频生成模型代码中,残差块(ResBlock)是生成器提取特征的核心组件,通过残差连接(Residual Connection)缓解深层网络的梯度消失问题,同时利用膨胀卷积(Dilated Convolution)扩大感受野,更有效地捕捉音频的时序依赖关系。代码中实现了两种残差块:ResBlock1ResBlock2,下面分别详细解析。

1. ResBlock1 详解

ResBlock1是更复杂的残差块结构,包含两组卷积层,通过不同膨胀率的卷积提取多尺度特征,再通过残差连接融合输入与输出。

1.1 初始化方法(init
class ResBlock1(torch.nn.Module):
    def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
        super(ResBlock1, self).__init__()
        self.h = h  # 模型超参数配置(未直接使用,预留扩展)
        # 第一组卷积:带不同膨胀率的1D卷积
        self.convs1 = nn.ModuleList([
            weight_norm(Conv1d(
                channels,  # 输入通道数
                channels,  # 输出通道数(与输入相同,保证残差连接维度匹配)
                kernel_size,  # 卷积核大小(如3)
                stride=1,  # 步长1(不改变时间维度)
                dilation=dilation[0],  # 膨胀率(控制感受野)
                padding=get_padding(kernel_size, dilation[0])  # 自动计算填充,保证输出长度不变
            )),
            # 重复定义另外两个卷积层,使用不同膨胀率dilation[1]和dilation[2]
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
                               padding=get_padding(kernel_size, dilation[1]))),
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2],
                               padding=get_padding(kernel_size, dilation[2])))
        ])
        self.convs1.apply(init_weights)  # 初始化卷积层权重

        # 第二组卷积:固定膨胀率=1的1D卷积(普通卷积)
        self.convs2 = nn.ModuleList([
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
                               padding=get_padding(kernel_size, 1))),
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
                               padding=get_padding(kernel_size, 1))),
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
                               padding=get_padding(kernel_size, 1)))
        ])
        self.convs2.apply(init_weights)  # 初始化卷积层权重

关键细节

  • 膨胀卷积(Dilated Convolution)convs1的三个卷积层分别使用dilation=(1,3,5),膨胀率越大,感受野(Receptive Field)越大(无需增加卷积核大小即可捕捉更长时序的依赖关系),适合音频这种长时序数据。
  • 权重归一化(weight_norm):对卷积层应用权重归一化,稳定训练过程(减少梯度波动,加速收敛)。
  • 填充计算(get_padding):通过get_padding(kernel_size, dilation)自动计算填充大小,确保卷积后时间维度不变(输入输出长度相同,满足残差连接的维度匹配)。
  • 两组卷积设计convs1(膨胀卷积)负责扩大感受野提取多尺度特征,convs2(普通卷积)负责特征细化,增强特征表达能力。
1.2 前向传播(forward)
def forward(self, x):
    for c1, c2 in zip(self.convs1, self.convs2):
        xt = F.leaky_relu(x, LRELU_SLOPE)  # 激活函数(LeakyReLU,斜率0.1)
        xt = c1(xt)  # 第一组膨胀卷积
        xt = F.leaky_relu(xt, LRELU_SLOPE)  # 再次激活
        xt = c2(xt)  # 第二组普通卷积
        x = xt + x  # 残差连接:当前输出 + 原始输入
    return x

数据流动过程

  1. 输入x先通过LeakyReLU激活(引入非线性)。
  2. 经过convs1的膨胀卷积提取多尺度特征。
  3. 再次激活后,经过convs2的普通卷积细化特征。
  4. 将卷积结果xt与原始输入x相加(残差连接),得到当前残差块的输出。
  5. 重复上述过程(共3次,与convs1/convs2的长度一致),逐步强化特征。

残差连接的作用:直接将输入x加到输出xt中,避免深层网络的梯度消失(梯度可通过x直接反向传播),同时保留原始特征,增强模型对细微特征的捕捉能力。

1.3 移除权重归一化(remove_weight_norm)
def remove_weight_norm(self):
    for l in self.convs1:
        remove_weight_norm(l)
    for l in self.convs2:
        remove_weight_norm(l)

在模型推理(生成音频)阶段,移除权重归一化可减少计算量,提高推理速度(训练时需要权重归一化稳定训练,推理时无需)。

2. ResBlock2 详解

ResBlock2是简化版的残差块,仅包含一组卷积层(膨胀卷积),计算量更小,适合对效率要求较高的场景。

2.1 初始化方法(init
class ResBlock2(torch.nn.Module):
    def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)):
        super(ResBlock2, self).__init__()
        self.h = h  # 超参数配置(预留扩展)
        self.convs = nn.ModuleList([
            weight_norm(Conv1d(
                channels,
                channels,
                kernel_size,
                stride=1,
                dilation=dilation[0],  # 第一个膨胀率
                padding=get_padding(kernel_size, dilation[0])
            )),
            weight_norm(Conv1d(
                channels,
                channels,
                kernel_size,
                stride=1,
                dilation=dilation[1],  # 第二个膨胀率
                padding=get_padding(kernel_size, dilation[1])
            ))
        ])
        self.convs.apply(init_weights)  # 初始化权重

与ResBlock1的差异

  • 仅包含一组卷积层convs(长度为2,与dilation=(1,3)匹配),无ResBlock1中的第二组普通卷积,结构更简单。
  • 膨胀率通常较小(如(1,3)),感受野扩展更温和,计算量更低。
2.2 前向传播(forward)
def forward(self, x):
    for c in self.convs:
        xt = F.leaky_relu(x, LRELU_SLOPE)  # 激活
        xt = c(xt)  # 膨胀卷积
        x = xt + x  # 残差连接
    return x

数据流动过程

  1. 输入x通过LeakyReLU激活。
  2. 经过convs的膨胀卷积提取特征。
  3. 卷积结果xt与原始输入x相加(残差连接)。
  4. 重复上述过程(共2次,与convs的长度一致)。

简化的意义:减少卷积层数量,降低计算复杂度,同时保留残差连接的核心优势(缓解梯度消失),适合资源有限的场景或作为轻量化模型的组件。

2.3 移除权重归一化(remove_weight_norm)
def remove_weight_norm(self):
    for l in self.convs:
        remove_weight_norm(l)

ResBlock1同理,推理阶段移除权重归一化以提高效率。

3. 两种残差块的对比与应用

特性 ResBlock1 ResBlock2
卷积层组数 2组(膨胀卷积+普通卷积) 1组(仅膨胀卷积)
卷积层数量 3+3=6层 2层
感受野 更大(多组膨胀率+普通卷积) 较小(仅两组膨胀率)
计算量 较高 较低
适用场景 追求高特征表达能力(如高质量生成) 追求效率(如快速推理)

在生成器中,通过参数h.resblock选择使用ResBlock1ResBlock2,两者均作为特征提取的基本单元,在每个上采样步骤后堆叠,逐步将梅尔频谱的特征转换为音频波形的特征。

总结

残差块是该音频生成模型的核心组件,通过:

  • 残差连接:解决深层网络梯度消失问题,保留原始特征。
  • 膨胀卷积:在不增加卷积核大小的情况下扩大感受野,捕捉音频的长时序依赖。
  • 权重归一化:稳定训练过程,加速收敛。

ResBlock1ResBlock2分别从“特征表达能力”和“计算效率”角度设计,可根据实际需求选择,共同支撑生成器从梅尔频谱到音频波形的高质量转换。

2. 生成器(Generator)

生成器(Generator)是该音频生成模型的核心组件,负责将输入的80维梅尔频谱(Mel-spectrogram)转换为1维原始音频波形。其设计核心是通过多步上采样逐步扩大时间维度(从梅尔频谱的短时长相音频的长时长),并通过残差块提取和强化特征,最终输出高质量音频。以下是详细解析:

1. 生成器的初始化(__init__方法)

生成器的初始化过程定义了从输入映射、上采样、特征提取到输出映射的完整组件链,核心参数依赖于配置h(包含上采样率、卷积核大小等超参数)。

class Generator(torch.nn.Module):
    def __init__(self, h):
        super(Generator, self).__init__()
        self.h = h  # 模型超参数配置(如采样率、卷积核大小等)
        self.num_kernels = len(h.resblock_kernel_sizes)  # 每个上采样步骤对应的残差块数量
        self.num_upsamples = len(h.upsample_rates)  # 上采样总步数
        # 输入映射:将80维梅尔频谱转换为高维特征
        self.conv_pre = weight_norm(Conv1d(80, h.upsample_initial_channel, 7, 1, padding=3))
        
        # 上采样层:通过转置卷积实现时间维度扩展
        self.ups = nn.ModuleList()
        for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
            self.ups.append(weight_norm(
                ConvTranspose1d(
                    # 输入通道数:初始通道数 // 2^i(每次上采样后通道数减半)
                    h.upsample_initial_channel // (2 **i),
                    # 输出通道数:初始通道数 // 2^(i+1)
                    h.upsample_initial_channel // (2** (i + 1)),
                    kernel_size=k,  # 上采样卷积核大小
                    stride=u,  # 上采样倍数(与upsample_rates对应)
                    padding=(k - u) // 2  # 计算填充,确保上采样后时间维度正确扩展
                )
            ))
        
        # 残差块组:每个上采样步骤后接多组残差块,用于特征提取
        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            # 当前上采样步骤后的特征通道数(随上采样逐步减半)
            ch = h.upsample_initial_channel // (2 **(i + 1))
            # 为每个上采样步骤添加num_kernels个残差块
            for j, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)):
                # 根据配置选择ResBlock1或ResBlock2
                resblock = ResBlock1 if h.resblock == '1' else ResBlock2
                self.resblocks.append(resblock(h, ch, k, d))
        
        # 输出映射:将高维特征转换为1维音频波形
        self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
        
        # 初始化权重
        self.ups.apply(init_weights)
        self.conv_post.apply(init_weights)
关键组件解析

1.输入映射(conv_pre)- 作用:将80维梅尔频谱(输入特征)映射到高维特征空间(通道数为h.upsample_initial_channel,如512),为后续特征提取做准备。

  • 实现:1D卷积(Conv1d),卷积核大小7,padding=3,确保时间维度不变(输入输出长度相同)。
  • 权重归一化:应用weight_norm稳定训练。

2.上采样层(self.ups)- 作用:通过转置卷积(ConvTranspose1d) 逐步扩大时间维度(梅尔频谱的时间步长较短,音频的时间步长较长,需通过上采样匹配)。

  • 核心参数:
    • upsample_rates:上采样倍数列表(如[8,8,2,2]),总上采样倍数为各值乘积(8×8×2×2=256,即梅尔频谱长度×256=音频长度)。
    • upsample_kernel_sizes:上采样卷积核大小(需与上采样率匹配,如[16,16,4,4]),确保通过padding计算((k-u)//2)使时间维度按u倍扩展。
  • 通道数变化:每次上采样后通道数减半(如512→256→128→64→32),平衡计算量与特征表达能力。

3.残差块组(self.resblocks)- 作用:对每个上采样步骤后的特征进行细化提取,捕捉音频的局部与全局时序依赖。

  • 组织方式:每个上采样步骤后接num_kernels个残差块(如3个),残差块类型由h.resblock决定(ResBlock1ResBlock2)。
  • 通道一致性:残差块的输入/输出通道数与当前上采样步骤后的通道数ch一致,确保残差连接有效。

4.输出映射(conv_post)- 作用:将最终的高维特征(如32通道)转换为1维音频波形。

  • 实现:1D卷积(Conv1d),卷积核大小7,padding=3,输出通道数=1。

2. 前向传播(forward方法)

前向传播定义了数据从梅尔频谱输入到音频输出的完整流动过程,核心是“上采样→残差特征提取→特征融合”的迭代过程。

def forward(self, x):
    # 步骤1:输入映射(梅尔频谱→高维特征)
    x = self.conv_pre(x)  # 形状:(batch, 80, T_mel) → (batch, initial_ch, T_mel)
    
    # 步骤2:多步上采样+残差特征提取
    for i in range(self.num_upsamples):
        x = F.leaky_relu(x, LRELU_SLOPE)  # 激活函数(引入非线性)
        x = self.ups[i](x)  # 上采样:时间维度扩大u倍,通道数减半
        
        # 多个残差块并行处理,结果平均融合
        xs = None
        for j in range(self.num_kernels):
            # 取出当前上采样步骤对应的第j个残差块
            resblock = self.resblocks[i * self.num_kernels + j]
            if xs is None:
                xs = resblock(x)  # 首次处理:直接赋值
            else:
                xs += resblock(x)  # 后续处理:累加特征
        x = xs / self.num_kernels  # 特征融合(平均):降低过拟合风险
    
    # 步骤3:输出映射(高维特征→音频波形)
    x = F.leaky_relu(x)  # 最终激活
    x = self.conv_post(x)  # 形状:(batch, ch, T_audio) → (batch, 1, T_audio)
    x = torch.tanh(x)  # 输出范围归一化到[-1, 1](音频信号常见范围)
    
    return x
数据流动细节
  • 输入阶段:输入x为梅尔频谱,形状为(batch_size, 80, T_mel)(80是梅尔频谱维度,T_mel是时间步长)。经过conv_pre后,形状变为(batch_size, initial_ch, T_mel)(如(32, 512, 100))。

  • 上采样阶段

    • 每次上采样通过self.ups[i]将时间维度扩大h.upsample_rates[i]倍(如8倍),通道数减半(如512→256)。
    • 上采样后,通过num_kernels个残差块并行处理(如3个),每个残差块输出相同形状的特征,累加后平均(xs / num_kernels),实现多尺度特征融合。
  • 输出阶段:最终特征经过conv_post转换为1通道,再通过tanh激活,输出形状为(batch_size, 1, T_audio)的音频波形(T_audio = T_mel × 总上采样倍数)。

3. 移除权重归一化(remove_weight_norm方法)

训练时为稳定收敛使用了权重归一化,但推理(生成音频)时无需,因此提供该方法移除归一化以提高效率:

def remove_weight_norm(self):
    print('Removing weight norm...')
    for l in self.ups:
        remove_weight_norm(l)  # 移除上采样层的权重归一化
    for l in self.resblocks:
        l.remove_weight_norm()  # 移除残差块的权重归一化
    remove_weight_norm(self.conv_pre)  # 移除输入映射的权重归一化
    remove_weight_norm(self.conv_post)  # 移除输出映射的权重归一化

4. 生成器设计亮点

1.渐进式上采样:通过多步小倍数上采样(而非一步大倍数上采样),避免直接扩展导致的特征模糊,逐步恢复高频细节。
2.残差特征融合:每个上采样步骤后用多个残差块并行处理并平均结果,融合多尺度特征,增强生成音频的丰富性。
3.膨胀卷积应用:残差块中使用膨胀卷积,在不增加计算量的情况下扩大感受野,有效捕捉音频的长时序依赖(如语音中的上下文信息)。
4.权重归一化:稳定训练过程,减少梯度波动,使深层网络更容易收敛。

总结

生成器通过“输入映射→多步上采样+残差特征提取→输出映射”的流程,将梅尔频谱转换为音频波形。核心设计围绕“渐进式扩展时间维度”和“多尺度特征融合”,结合残差连接和膨胀卷积,在保证生成质量的同时,平衡了计算效率与训练稳定性。这一结构使其特别适合作为语音合成系统中的声码器(Vocoder),生成高保真、自然的音频。

3. 鉴别器(Discriminator)

在该音频生成模型中,鉴别器(Discriminator)的核心作用是区分“真实音频”和“生成器输出的伪造音频”,通过与生成器的对抗训练,推动生成器生成更逼真的音频。代码中设计了两种互补的鉴别器结构:多周期鉴别器(MultiPeriodDiscriminator)多尺度鉴别器(MultiScaleDiscriminator),从不同角度捕捉音频特征,增强判别能力。以下是详细解析:

1. 多周期鉴别器(MultiPeriodDiscriminator)

多周期鉴别器通过周期子序列分割捕捉音频的周期性模式(如语音的基频、音乐的节奏等),从多个周期尺度鉴别音频真实性。它包含多个子鉴别器DiscriminatorP,每个子鉴别器专注于特定周期的特征。

1.1 子鉴别器(DiscriminatorP)

DiscriminatorP是多周期鉴别器的基本单元,针对特定周期period处理音频,将1D音频转换为2D周期特征后进行判别。

初始化(init
class DiscriminatorP(torch.nn.Module):
    def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
        super(DiscriminatorP, self).__init__()
        self.period = period  # 周期(如2、3、5等,用于分割音频为子序列)
        # 选择权重归一化或谱归一化(谱归一化更适合稳定GAN训练)
        norm_f = weight_norm if use_spectral_norm == False else spectral_norm
        
        # 2D卷积层序列:逐步提取周期特征,通道数从1→32→128→512→1024
        self.convs = nn.ModuleList([
            norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
            norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
            norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
            norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
            norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))),  # 步长1,不改变空间维度
        ])
        # 输出层:将特征映射为判别分数(真实/伪造)
        self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))

关键设计

  • 周期分割:针对特定周期period(如2),将音频分割为长度为period的子序列(如音频[t0,t1,t2,t3]按周期2分割为[[t0,t1], [t2,t3]]),转换为2D特征(形状:(batch, 1, 子序列数, period)),便于捕捉周期性模式。
  • 2D卷积:使用Conv2d处理周期特征,卷积核形状为(kernel_size, 1),仅在子序列数维度(时间方向)滑动,保留周期内的时序关系。
  • 归一化选择:支持weight_norm(权重归一化)和spectral_norm(谱归一化),后者通过限制权重矩阵的谱范数,更能稳定GAN训练。
前向传播(forward)
def forward(self, x):
    fmap = []  # 存储各层特征图(用于特征匹配损失)
    
    # 步骤1:1D音频→2D周期特征(按周期分割)
    b, c, t = x.shape  # x形状:(batch, 1, 时间步长)
    if t % self.period != 0:  # 补零使时间步长为周期的整数倍
        n_pad = self.period - (t % self.period)
        x = F.pad(x, (0, n_pad), "reflect")  # 反射补零(减少边界效应)
        t = t + n_pad
    # 重塑为2D:(batch, 1, 子序列数, 周期长度)
    x = x.view(b, c, t // self.period, self.period)
    
    # 步骤2:通过卷积层提取特征并记录特征图
    for l in self.convs:
        x = l(x)  # 卷积操作
        x = F.leaky_relu(x, LRELU_SLOPE)  # 激活函数(LeakyReLU,斜率0.1)
        fmap.append(x)  # 保存当前层特征图
    
    # 步骤3:输出判别分数
    x = self.conv_post(x)  # 映射为判别分数(形状:(batch, 1, ...))
    fmap.append(x)  # 保存输出层特征图
    x = torch.flatten(x, 1, -1)  # 展平为(batch, 分数)
    
    return x, fmap  # 返回判别分数和特征图列表

核心流程

  1. 周期转换:将1D音频转换为2D周期特征,突出周期性模式(如语音的基频周期)。
  2. 特征提取:通过多组2D卷积逐步提升通道数(1→1024),压缩时间维度(子序列数减少),捕捉高层周期特征。
  3. 判别输出:最终通过conv_post输出判别分数(值越大越可能是真实音频),同时记录各层特征图用于后续损失计算。
1.2 多周期鉴别器整体(MultiPeriodDiscriminator)

多周期鉴别器由多个不同周期的DiscriminatorP组成,从多个周期尺度(2、3、5、7、11)联合判别,覆盖音频中不同频率的周期性模式。

class MultiPeriodDiscriminator(torch.nn.Module):
    def __init__(self):
        super(MultiPeriodDiscriminator, self).__init__()
        # 包含5个不同周期的子鉴别器(周期2、3、5、7、11)
        self.discriminators = nn.ModuleList([
            DiscriminatorP(2),
            DiscriminatorP(3),
            DiscriminatorP(5),
            DiscriminatorP(7),
            DiscriminatorP(11),
        ])
    
    def forward(self, y, y_hat):
        # y:真实音频;y_hat:生成器输出的伪造音频
        y_d_rs = []  # 真实音频的判别分数列表
        y_d_gs = []  # 伪造音频的判别分数列表
        fmap_rs = []  # 真实音频的特征图列表
        fmap_gs = []  # 伪造音频的特征图列表
        
        # 每个子鉴别器分别处理真实和伪造音频
        for i, d in enumerate(self.discriminators):
            y_d_r, fmap_r = d(y)  # 真实音频的判别结果
            y_d_g, fmap_g = d(y_hat)  # 伪造音频的判别结果
            y_d_rs.append(y_d_r)
            fmap_rs.append(fmap_r)
            y_d_gs.append(y_d_g)
            fmap_gs.append(fmap_g)
        
        return y_d_rs, y_d_gs, fmap_rs, fmap_gs

设计目的:不同音频(如语音、音乐)的周期性模式不同(如语音基频约50-500Hz,对应周期20-2ms),使用多个周期(2、3、5、7、11)可覆盖更广泛的周期范围,避免单一周期的判别偏差,提升整体判别能力。

2. 多尺度鉴别器(MultiScaleDiscriminator)

多尺度鉴别器通过下采样生成不同时间尺度的音频,从多个分辨率(原始、1/2、1/4)捕捉音频的局部细节和全局结构,与多周期鉴别器形成互补。它包含多个子鉴别器DiscriminatorS,每个子鉴别器处理特定尺度的音频。

2.1 子鉴别器(DiscriminatorS)

DiscriminatorS是多尺度鉴别器的基本单元,使用1D卷积直接处理音频,通过步长和分组卷积提取不同尺度的特征。

初始化(init
class DiscriminatorS(torch.nn.Module):
    def __init__(self, use_spectral_norm=False):
        super(DiscriminatorS, self).__init__()
        norm_f = weight_norm if use_spectral_norm == False else spectral_norm
        
        # 1D卷积层序列:逐步提取时序特征,通道数从1→128→256→512→1024
        self.convs = nn.ModuleList([
            norm_f(Conv1d(1, 128, 15, 1, padding=7)),  # 步长1,不改变时间维度
            norm_f(Conv1d(128, 128, 41, 2, groups=4, padding=20)),  # 步长2下采样,分组卷积
            norm_f(Conv1d(128, 256, 41, 2, groups=16, padding=20)),  # 步长2下采样
            norm_f(Conv1d(256, 512, 41, 4, groups=16, padding=20)),  # 步长4下采样
            norm_f(Conv1d(512, 1024, 41, 4, groups=16, padding=20)),  # 步长4下采样
            norm_f(Conv1d(1024, 1024, 41, 1, groups=16, padding=20)),  # 步长1
            norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),  # 步长1,小卷积核细化特征
        ])
        # 输出层:映射为判别分数
        self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))

关键设计

  • 1D卷积直接处理:无需周期分割,直接对1D音频进行卷积,更侧重捕捉时序连续性特征(如音频的瞬态变化)。
  • 下采样与分组卷积:通过步长(2、4)实现时间维度下采样(降低分辨率),同时使用分组卷积(groups)减少参数计算量,增强特征多样性。
  • 多尺度特征:卷积核大小从15→41→5,结合不同步长,捕捉从局部到全局的时序特征。
前向传播(forward)
def forward(self, x):
    fmap = []  # 存储各层特征图
    
    # 步骤1:通过卷积层提取特征并记录特征图
    for l in self.convs:
        x = l(x)  # 卷积操作
        x = F.leaky_relu(x, LRELU_SLOPE)  # 激活函数
        fmap.append(x)  # 保存当前层特征图
    
    # 步骤2:输出判别分数
    x = self.conv_post(x)  # 映射为判别分数
    fmap.append(x)  # 保存输出层特征图
    x = torch.flatten(x, 1, -1)  # 展平为(batch, 分数)
    
    return x, fmap  # 返回判别分数和特征图列表

核心流程:直接对1D音频进行多步卷积和下采样,逐步压缩时间维度、提升通道数,捕捉不同尺度的时序特征,最终输出判别分数和特征图。

2.2 多尺度鉴别器整体(MultiScaleDiscriminator)

多尺度鉴别器由3个DiscriminatorS组成,通过平均池化生成不同尺度的音频(原始、1/2、1/4),从粗到细覆盖音频的全局和局部特征。

class MultiScaleDiscriminator(torch.nn.Module):
    def __init__(self):
        super(MultiScaleDiscriminator, self).__init__()
        # 3个子鉴别器(第1个使用谱归一化,增强稳定性)
        self.discriminators = nn.ModuleList([
            DiscriminatorS(use_spectral_norm=True),
            DiscriminatorS(),
            DiscriminatorS(),
        ])
        # 平均池化层:用于生成低尺度音频(1/2、1/4)
        self.meanpools = nn.ModuleList([
            AvgPool1d(4, 2, padding=2),  # 下采样1/2(核4,步长2)
            AvgPool1d(4, 2, padding=2)   # 再下采样1/2(总1/4)
        ])
    
    def forward(self, y, y_hat):
        y_d_rs = []  # 真实音频的判别分数列表
        y_d_gs = []  # 伪造音频的判别分数列表
        fmap_rs = []  # 真实音频的特征图列表
        fmap_gs = []  # 伪造音频的特征图列表
        
        # 每个子鉴别器处理不同尺度的音频
        for i, d in enumerate(self.discriminators):
            if i != 0:  # 第1个处理原始尺度,第2/3个处理下采样后的尺度
                y = self.meanpools[i-1](y)  # 真实音频下采样
                y_hat = self.meanpools[i-1](y_hat)  # 伪造音频下采样
            # 子鉴别器处理当前尺度的音频
            y_d_r, fmap_r = d(y)
            y_d_g, fmap_g = d(y_hat)
            y_d_rs.append(y_d_r)
            fmap_rs.append(fmap_r)
            y_d_gs.append(y_d_g)
            fmap_gs.append(fmap_g)
        
        return y_d_rs, y_d_gs, fmap_rs, fmap_gs

设计目的:音频的高频细节(如瞬态音)和低频结构(如整体节奏)需要不同分辨率的特征捕捉。通过平均池化生成1/2、1/4尺度的音频,使子鉴别器专注于不同频率范围的特征,提升对细微差异的判别能力。

3. 鉴别器的协同作用与训练目标

两种鉴别器(多周期+多尺度)从不同角度判别音频真实性:

  • 多周期鉴别器:聚焦周期性模式(如语音基频、音乐节拍),擅长捕捉“韵律一致性”。
  • 多尺度鉴别器:聚焦时序连续性(如音频的平滑过渡、瞬态变化),擅长捕捉“细节真实性”。

训练时,鉴别器的目标是最大化对真实音频的判别分数(接近1),最小化对伪造音频的判别分数(接近0);而生成器则通过对抗训练,尝试欺骗鉴别器(使伪造音频的判别分数接近1)。同时,鉴别器输出的特征图用于计算“特征匹配损失”,进一步约束生成音频的特征分布与真实音频一致。

总结

鉴别器通过“多周期+多尺度”的组合设计,从周期性和时序连续性两个维度全面判别音频真实性:

  • 每个子鉴别器通过卷积层提取特征,输出判别分数和特征图。
  • 多周期设计覆盖不同周期模式,多尺度设计覆盖不同分辨率特征。
  • 与生成器的对抗训练推动生成器生成更逼真、细节更丰富的音频。

这种结构是高性能音频生成模型(如GHiFi-GAN)的核心设计,使其能够生成接近真实的高保真音频。

4. 损失函数

在该音频生成GAN模型中,损失函数是连接生成器(Generator)和鉴别器(Discriminator)的核心,通过对抗训练推动双方迭代优化。代码中定义了三类关键损失函数:** 特征匹配损失(feature_loss) 鉴别器损失(discriminator_loss) 生成器损失(generator_loss)**,它们协同作用以确保生成音频的真实性和质量。以下是详细解析:

1. 特征匹配损失(feature_loss)

特征匹配损失用于约束生成音频的特征分布与真实音频一致,补充对抗损失的不足(仅靠对抗损失可能导致生成样本“骗过”鉴别器但特征不真实)。它通过计算真实音频和生成音频在鉴别器各层特征图的差异,引导生成器学习更细腻的特征。

def feature_loss(fmap_r, fmap_g):
    loss = 0
    # 遍历所有鉴别器的特征图列表(多周期+多尺度鉴别器的特征图)
    for dr, dg in zip(fmap_r, fmap_g):
        # 遍历单个鉴别器内的各层特征图
        for rl, gl in zip(dr, dg):
            # 计算当前层特征图的L1损失(平均绝对误差)
            loss += torch.mean(torch.abs(rl - gl))
    # 缩放损失值(经验系数,增强该损失的权重)
    return loss * 2
关键细节
  • 输入fmap_r是真实音频经过鉴别器后输出的特征图列表,fmap_g是生成音频经过相同鉴别器后输出的特征图列表(包含多周期和多尺度鉴别器的所有层特征)。
  • 计算方式:对每一层特征图(rl为真实特征,gl为生成特征)计算L1损失(torch.abs(rl - gl)的均值),累加所有层的损失后乘以2(缩放系数,平衡与其他损失的权重)。
  • 作用:强制生成音频在鉴别器的中间特征层面与真实音频相似,避免生成器仅优化“骗过鉴别器”的表层特征,而忽略音频的细节结构(如频谱分布、时序连贯性)。

2. 鉴别器损失(discriminator_loss)

鉴别器的目标是最大化对“真实音频”的判别分数(接近1),同时最小化对“生成音频”的判别分数(接近0)。该损失函数量化了鉴别器的分类误差,指导其优化以更好地区分真假音频。

def discriminator_loss(disc_real_outputs, disc_generated_outputs):
    loss = 0
    r_losses = []  # 记录每个鉴别器对真实音频的损失
    g_losses = []  # 记录每个鉴别器对生成音频的损失
    # 遍历所有鉴别器的输出(多周期+多尺度鉴别器)
    for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
        # 真实音频损失:希望判别分数dr接近1,用(1-dr)^2衡量偏差
        r_loss = torch.mean((1 - dr) **2)
        # 生成音频损失:希望判别分数dg接近0,用dg^2衡量偏差
        g_loss = torch.mean(dg** 2)
        # 累加单个鉴别器的总损失
        loss += (r_loss + g_loss)
        # 记录单个鉴别器的损失值(用于监控训练过程)
        r_losses.append(r_loss.item())
        g_losses.append(g_loss.item())
    return loss, r_losses, g_losses
关键细节
  • 输入disc_real_outputs是所有鉴别器对真实音频的判别分数列表,disc_generated_outputs是所有鉴别器对生成音频的判别分数列表。
  • 计算方式
    • 对真实音频:使用平方损失(1 - dr)^2,当dr=1时损失为0(完美判别),dr越小损失越大。
    • 对生成音频:使用平方损失dg^2,当dg=0时损失为0(完美判别),dg越大损失越大。
    • 总损失为所有鉴别器的真实损失与生成损失之和。
  • 作用:推动鉴别器学习真实音频与生成音频的差异,提升分类能力。每个鉴别器(多周期/多尺度)的损失被单独记录,便于监控不同鉴别器的训练状态。

3. 生成器损失(generator_loss)

生成器的目标是“欺骗”鉴别器,使鉴别器对生成音频的判别分数接近1。该损失函数量化了生成器的欺骗效果,指导其优化以生成更逼真的音频。

def generator_loss(disc_outputs):
    loss = 0
    gen_losses = []  # 记录每个鉴别器上的生成损失
    # 遍历所有鉴别器对生成音频的输出
    for dg in disc_outputs:
        # 生成损失:希望判别分数dg接近1,用(1-dg)^2衡量偏差
        l = torch.mean((1 - dg) **2)
        gen_losses.append(l)
        loss += l
    return loss, gen_losses

####** 关键细节 - 输入 disc_outputs是所有鉴别器对生成音频的判别分数列表(与disc_generated_outputs一致)。
-
计算方式 :对每个鉴别器的生成音频判别分数dg,使用平方损失(1 - dg)^2,当dg=1时损失为0(完美欺骗),dg越小损失越大。总损失为所有鉴别器的生成损失之和。
-
作用 **:推动生成器优化输出,使生成音频在所有鉴别器(多周期/多尺度)上都被误认为真实音频,迫使生成器学习真实音频的全面特征(周期性、时序连续性等)。

###** 4. 损失函数的协同作用 在实际训练中,三类损失函数通过以下方式协同工作:
1.
鉴别器优化 :单独最小化discriminator_loss,使其能更准确地区分真假音频。
2.
生成器优化 **:最小化“生成器损失 + 特征匹配损失”(通常特征匹配损失会乘以一个权重系数,如10),既要求生成音频能欺骗鉴别器(对抗目标),又要求其特征分布接近真实音频(特征匹配目标)。

这种组合避免了GAN训练中常见的“模式崩溃”(生成样本多样性不足)和“训练不稳定”问题,同时保证了生成音频的高质量(细节丰富、真实感强)。

###** 总结 - 特征匹配损失 :从特征层面约束生成音频与真实音频的一致性,提升细节质量。
-
鉴别器损失 :指导鉴别器学习真假音频的差异,增强判别能力。
-
生成器损失 **:指导生成器欺骗鉴别器,推动生成更逼真的音频。

三者协同形成了完整的训练目标,使生成器能够逐步学习真实音频的分布,最终生成高保真、自然的音频波形。

总结

该代码实现了一个高性能音频生成模型,核心设计包括:

  • 生成器:通过多步上采样和残差块,从梅尔频谱生成音频。
  • 鉴别器:多周期+多尺度设计,从不同角度区分真假音频,增强判别能力。
  • 损失函数:结合对抗损失和特征匹配损失,平衡生成质量和训练稳定性。

这种结构常用于语音合成系统中的声码器(如TTS中的最后一步:从梅尔频谱生成波形),能生成高质量、高保真的音频。