从GAN到WGAN-GP:生成对抗网络的进化之路与实战详解
在深度学习的生成模型领域,GAN (Generative Adversarial Networks)无疑是最耀眼的明星之一。从2014年 Ian Goodfellow 提出 GAN 至今,它已经经历了无数次的迭代和进化。其中,WGAN和WGAN-GP是两次里程碑式的改进,它们从数学原理上解决了原始 GAN 训练不稳定、模式崩塌等“顽疾”。
本文将深入浅出地梳理从 GAN 到 WGAN 再到 WGAN-GP 的演进逻辑,分析它们背后的数学直觉,并提供核心代码实现。
一、GAN:天才的博弈
1.1 基本原理
GAN 的灵感来源于博弈论。它由两个网络组成:
- 生成器 (Generator, G):负责制造“假钞”(生成数据)。它的目标是生成尽可能逼真的数据,以骗过判别器。
- 判别器 (Discriminator, D):负责充当“验钞机”。它的目标是尽可能准确地分辨出输入数据是真实的(来自数据集)还是假的(由 G 生成)。
两者的目标函数是一个Min-Max 博弈:
min G max D V ( D , G ) = E x ∼ P d a t a ( x ) [ log D ( x ) ] + E z ∼ P z ( z ) [ log ( 1 − D ( G ( z ) ) ) ] \min_G \max_D V(D, G) = \mathbb{E}_{x \sim P_{data}(x)} [\log D(x)] + \mathbb{E}_{z \sim P_{z}(z)} [\log (1 - D(G(z)))]GminDmaxV(D,G)=Ex∼Pdata(x)[logD(x)]+Ez∼Pz(z)[log(1−D(G(z)))]
1.2 GAN 的阿喀琉斯之踵
虽然 GAN 的思想非常精妙,但在实际训练中,研究者们发现它非常难训练,主要面临以下问题:
- 训练不稳定:G 和 D 需要小心翼翼地平衡。如果 D 太强,G 的梯度会消失;如果 D 太弱,G 又学不到东西。
- 模式崩塌 (Mode Collapse):G 发现生成某一种特定的样本特别容易骗过 D,于是它就只生成这一种样本,失去了多样性。
- 无法指示训练进程:GAN 的 Loss 值通常震荡剧烈,无法像监督学习那样通过 Loss 下降来判断模型是否变好。
根本原因:原始 GAN 等价于在该小 JS 散度 (Jensen-Shannon Divergence)。当真实分布P r P_rPr和生成分布P g P_gPg重叠很少甚至不重叠时(在高维空间中这很常见),JS 散度是常数,导致梯度消失,G 无法获得有效的更新方向。
二、WGAN:推土机距离的救赎
为了解决 GAN 的问题,2017年 Arjovsky 等人提出了Wasserstein GAN (WGAN)。
2.1 核心思想:Wasserstein 距离
WGAN 引入了Wasserstein 距离(也称 Earth-Mover Distance,EM 距离,推土机距离)。
简单来说,如果把两个分布看作是两堆土,EM 距离就是把一堆土搬到另一堆土的位置所消耗的最小“功”(质量 $ imes$ 距离)。
优势:即使两个分布完全不重叠,Wasserstein 距离仍然能提供平滑的梯度,指引 G 慢慢向真实分布靠拢。这彻底解决了梯度消失的问题。
2.2 WGAN 的改进点
为了近似计算 Wasserstein 距离,WGAN 做了以下改动:
- 判别器变身“评论家” (Critic):D 的最后一层去掉 Sigmoid,不再输出概率,而是输出一个实数值(评分)。
- Loss 改变:
- L D = E x ~ ∼ P g [ D ( x ~ ) ] − E x ∼ P r [ D ( x ) ] L_D = \mathbb{E}_{\tilde{x} \sim P_g}[D(\tilde{x})] - \mathbb{E}_{x \sim P_r}[D(x)]LD=Ex~∼Pg[D(x~)]−Ex∼Pr[D(x)]
- L G = − E x ~ ∼ P g [ D ( x ~ ) ] L_G = -\mathbb{E}_{\tilde{x} \sim P_g}[D(\tilde{x})]LG=−Ex~∼Pg[D(x~)]
- 注意:不再取 log。
- 权重剪枝 (Weight Clipping):为了满足 Wasserstein 距离成立的数学条件(1-Lipschitz 连续性),WGAN 强制将 Critic 的所有参数限制在[ − c , c ] [-c, c][−c,c]之间(例如 c=0.01)。
2.3 WGAN 的局限
虽然 WGAN 解决了训练稳定性问题,并且 Loss 值终于可以代表图像质量了,但Weight Clipping过于简单粗暴:
- 它限制了 Critic 的表达能力。
- 容易导致参数集中在截断边界(-c 和 c),不仅浪费了神经网络的拟合能力,还可能引发梯度爆炸或消失。
三、WGAN-GP:梯度惩罚的优雅
为了解决 Weight Clipping 的副作用,Gulrajani 等人提出了WGAN-GP (WGAN with Gradient Penalty)。
3.1 核心改进:梯度惩罚
WGAN-GP 依然沿用了 WGAN 的架构和 Loss,但改变了实现 1-Lipschitz 约束的方法。
它不再直接剪裁参数,而是在 Critic 的 Loss 函数中增加了一个梯度惩罚项 (Gradient Penalty)。根据数学推导,如果一个函数是 1-Lipschitz 的,那么它的梯度范数应该在处处不超过 1。WGAN-GP 鼓励梯度的范数接近 1。
3.2 Loss 函数详解
新的 Critic Loss 如下:
L = E x ~ ∼ P g [ D ( x ~ ) ] − E x ∼ P r [ D ( x ) ] ⏟ 原始 WGAN Loss + λ E x ^ ∼ P x ^ [ ( ∣ ∣ ∇ x ^ D ( x ^ ) ∣ ∣ 2 − 1 ) 2 ] ⏟ Gradient Penalty L = \underbrace{\mathbb{E}_{\tilde{x} \sim P_g}[D(\tilde{x})] - \mathbb{E}_{x \sim P_r}[D(x)]}_{\text{原始 WGAN Loss}} + \lambda \underbrace{\mathbb{E}_{\hat{x} \sim P_{\hat{x}}}[(||\nabla_{\hat{x}} D(\hat{x})||_2 - 1)^2]}_{\text{Gradient Penalty}}L=原始WGAN LossEx~∼Pg[D(x~)]−Ex∼Pr[D(x)]+λGradient PenaltyEx^∼Px^[(∣∣∇x^D(x^)∣∣2−1)2]
其中:
- λ \lambdaλ是惩罚系数(通常取 10)。
- x ^ \hat{x}x^是采样点。我们在真实样本x xx和生成样本x ~ \tilde{x}x~之间随机插值采样:x ^ = ϵ x + ( 1 − ϵ ) x ~ \hat{x} = \epsilon x + (1-\epsilon) \tilde{x}x^=ϵx+(1−ϵ)x~,ϵ ∼ U [ 0 , 1 ] \epsilon \sim U[0, 1]ϵ∼U[0,1]。我们约束这些插值点上的梯度范数接近 1。
四、PyTorch 核心代码实现
下面是 WGAN-GP 核心部分的 PyTorch 实现代码。
4.1 梯度惩罚计算函数
importtorchimporttorch.nnasnnimporttorch.autogradasautograddefcompute_gradient_penalty(D,real_samples,fake_samples,device):""" 计算 WGAN-GP 的梯度惩罚项 """# 1. 随机权重插值alpha=torch.rand(real_samples.size(0),1,1,1).to(device)# 假设输入是图片 (N, C, H, W),需要根据维度调整 alpha 的形状interpolates=(alpha*real_samples+(1-alpha)*fake_samples).requires_grad_(True)# 2. 将插值样本输入判别器d_interpolates=D(interpolates)# 3. 计算判别器输出相对于插值样本的梯度fake=torch.ones(real_samples.shape[0],1).to(device)gradients=autograd.grad(outputs=d_interpolates,inputs=interpolates,grad_outputs=fake,create_graph=True,retain_graph=True,only_inputs=True,)[0]# 4. 计算梯度范数gradients=gradients.view(gradients.size(0),-1)gradient_penalty=((gradients.norm(2,dim=1)-1)**2).mean()returngradient_penalty4.2 训练循环示例
# 超参数lambda_gp=10n_critic=5# 每训练 1 次 Generator,训练 5 次 Critic# ... 初始化 DataLoader, Generator(G), Critic(D), Optimizers ...fori,(imgs,_)inenumerate(dataloader):real_imgs=imgs.to(device)batch_size=real_imgs.shape[0]# ---------------------# 训练 Critic (D)# ---------------------optimizer_D.zero_grad()# 生成假样本z=torch.randn(batch_size,latent_dim).to(device)fake_imgs=G(z)# 计算 WGAN Loss# 注意:为了使用 min 优化器,我们将最大化问题转化为最小化负数real_validity=D(real_imgs)fake_validity=D(fake_imgs)# WGAN Loss: -E[D(x)] + E[D(G(z))]d_loss=-torch.mean(real_validity)+torch.mean(fake_validity)# 计算梯度惩罚gradient_penalty=compute_gradient_penalty(D,real_imgs,fake_imgs.detach(),device)# 总 Lossd_loss=d_loss+lambda_gp*gradient_penalty d_loss.backward()optimizer_D.step()# ---------------------# 训练 Generator (G)# ---------------------# 每 n_critic 步训练一次 Gifi%n_critic==0:optimizer_G.zero_grad()# 重新生成假样本 (可选,也可以复用上面的,但为了计算图清晰通常重新生成)gen_imgs=G(z)# G 的目标是让 D 给出的分数越高越好 (即最小化 -D(G(z)))g_loss=-torch.mean(D(gen_imgs))g_loss.backward()optimizer_G.step()ifi%100==0:print(f"[Epoch{epoch}/{opt.n_epochs}] [Batch{i}/{len(dataloader)}] [D loss:{d_loss.item()}] [G loss:{g_loss.item()}]")五、总结与建议
| 特性 | GAN | WGAN | WGAN-GP |
|---|---|---|---|
| 判别器输出 | 概率 [0, 1] (Sigmoid) | 实数评分 (无 Sigmoid) | 实数评分 (无 Sigmoid) |
| Loss 函数 | JS 散度 (Log Loss) | Wasserstein 距离 (均值差) | Wasserstein 距离 + 梯度惩罚 |
| 约束方法 | 无 | Weight Clipping(参数截断) | Gradient Penalty(梯度惩罚) |
| 训练稳定性 | 差 (易模式崩塌) | 较好 | 极好 |
| 收敛速度 | 快 (但不稳定) | 较慢 | 适中 |
实际使用建议:
如果你要处理新的生成任务,首选 WGAN-GP。它几乎不需要繁琐的超参数调试就能稳定训练,而且 Loss 曲线能真实反映图像质量的提升。虽然计算梯度惩罚会稍微增加一点训练时间,但相比于原始 GAN 调参的痛苦,这是非常值得的投入。