GAN_代码生成
基于tensorflow
依赖
- 3个外部库:
tensorflow、numpy、matplotlib - 用
Sequential搭网络,用Model做GAN的链式模型???
超参
不会在反向传播里更新,但直接决定模型能否收敛、生成质量、训练速度
这些参数是要可重复,为了调参方便,可自动化搜索
[!IMPORTANT]
这三个参数是生成对抗网络(GAN)或深度学习模型训练中最核心的超参数,常用于控制模型的训练过程和输入数据形态,简要解析如下:
EPOCHS = 50
“EPOCHS”(迭代次数)指整个训练数据集被模型完整训练一遍的次数。
- 50表示训练时,所有训练样本会被模型“过”50次;
- 作用:太少会导致模型“欠拟合”(没学会数据规律),太多可能导致“过拟合”(只记住训练数据、泛化能力差),50是常见的中等迭代次数,适用于数据量适中、模型复杂度一般的场景。
BATCH_SIZE = 128
“BATCH_SIZE”(批次大小)指每次模型参数更新时,输入的样本数量。
- 128表示每次训练,模型会同时处理128个样本,计算这128个样本的平均误差后,再调整一次参数;
- 作用:平衡训练效率和稳定性——批次太小会导致参数更新震荡(不稳定)、训练慢;批次太大可能占用过多显存,且单次更新对整体数据的代表性下降,128是深度学习中兼顾效率与稳定性的常用值。
NOISE_DIM = 100
“NOISE_DIM”(噪声维度)是生成模型(如GAN的生成器)的核心输入参数,仅用于需要“从随机噪声生成数据”的场景(如生成图片、文本等)。
- 100表示生成器的输入是一个“100维的随机向量”(如符合正态分布的随机数);
- 作用:这个高维噪声向量是生成器的“创意源头”,模型通过学习将100维的无序噪声,映射为有序的目标数据(如一张256x256的图片),100是生成类模型中常见的噪声维度(维度太低会限制生成多样性,太高会增加模型学习难度)。
数据
设置函数load_data()
- 相当于无监督学习,不用
y_train(标签) - 归一化到[-1,1],因为生成器最后一层是
tanh,输出范围也是[-1,1]
[!IMPORTANT]
这个函数
load_data()的作用是加载并预处理MNIST手写数字数据集,主要步骤解析如下:
通过
mnist.load_data()加载数据集,该函数返回格式为(训练集, 测试集),其中每个集合又包含(图像数据, 标签)。这里用(x_train, _), (_, _)只保留训练集的图像数据x_train,忽略训练集标签和测试集的所有数据(用_表示不需要的变量)。对
x_train进行归一化处理:(x_train.astype("float32") - 127.5) / 127.5。MNIST图像像素值范围是0-255,减去127.5再除以127.5后,像素值会被映射到[-1, 1]区间,这是很多生成模型(如GAN)常用的预处理方式。重塑数据形状:
x_train.reshape(-1, 784)。MNIST图像原始形状是28×28的二维数组,这里将其展平为784个元素的一维数组(-1表示自动计算该维度的大小,保持样本数量不变),便于后续输入模型进行处理。最终函数返回预处理后的训练集图像数据,可直接用于模型训练。
模型构件
build_generator()
用Sequential()创建模型,再通过.compile(loss, optimizor)(参数了解一下),做预准备
[!IMPORTANT]
这段代码定义了一个生成对抗网络(GAN)中的生成器模型,主要功能是将随机噪声转换为类似目标数据(结合输出维度784推测可能是28x28像素的灰度图像,如MNIST数据集)的伪造数据。以下是简要解析:
模型结构:
- 采用
Sequential顺序模型,由4个全连接层(Dense)组成- 输入为维度
NOISE_DIM的随机噪声(噪声向量)- 前3层使用
LeakyReLU激活函数(alpha=0.2,控制负斜率),逐步提升维度(256→512→1024),增强特征表达能力- 输出层为784维(对应28x28图像展平后的维度),使用
tanh激活函数,将输出值映射到[-1, 1]范围(符合图像数据归一化常见处理)编译配置:
.compile()
- 损失函数采用
binary_crossentropy(二元交叉熵),这是GAN中生成器与判别器对抗训练的常用损失- 优化器为
Adam,学习率0.0002,beta_1=0.5,是GAN训练中经过实践验证的较优参数组合(较小的beta_1使动量估计更关注近期梯度)整体来看,该生成器通过多层全连接网络对噪声进行非线性变换,最终生成与目标数据分布相似的输出,是典型的GAN生成器架构。
100维噪音 ==> 256 ==> 512 ==> 1024 ==> 784(假图)
build_discriminator()
同样的方式,参数内容不同
多了个Dropout(0.3)来防止过拟合,
[!IMPORTANT]
这段代码定义了一个用于生成对抗网络(GAN)中的判别器模型,功能是判断输入数据是真实样本还是生成器生成的假样本,具体解析如下:
模型结构:
- 采用
Sequential顺序模型,由多个全连接层(Dense)堆叠而成- 输入层为
Dense(1024, input_dim=784),说明输入数据是784维(对应MNIST数据集的28×28像素图像展平后的数据),输出1024维特征- 中间层依次为512维、256维的全连接层,每层后均使用
LeakyReLU(alpha=0.2)激活函数(解决ReLU的死亡神经元问题,小斜率0.2允许少量负输入通过)- 各层间加入
Dropout(0.3),随机丢弃30%的神经元,防止过拟合- **输出层为
Dense(1, activation="sigmoid")**,输出单个0-1之间的值(1表示判断为真实样本,0表示判断为假样本)模型配置:
- 使用
binary_crossentropy(二元交叉熵)损失函数,适合二分类任务(真实/假样本判断)- 优化器为
Adam,学习率0.0002,beta_1=0.5(GAN中常用的参数设置,有助于稳定训练)整体来看,这是一个典型的GAN判别器架构,通过多层全连接网络提取特征,最终输出对输入数据真实性的判断概率。
build_gan
作为“生成器部分”的训练通路
噪音 ==> 生成器 ==> 假图 ==> 判别器 ==> 分数,最后
Model生成gan,并.compile设置判别器
trainable=False保证反向传播只更新生成器权重
[!IMPORTANT]
这段代码定义了一个构建生成对抗网络(GAN)的函数
build_gan,主要作用是将生成器(gen)和判别器(disc)组合成一个完整的GAN模型,用于训练生成器。以下是简要解析:
冻结判别器参数:
disc.trainable = False
这一步是关键,在训练GAN(即训练生成器时),需要固定判别器的参数,避免其在生成器的训练过程中被更新,确保生成器的优化目标是“欺骗”当前状态的判别器。定义GAN的输入和输出:
gan_input = Input(shape=(NOISE_DIM,)):定义GAN的输入为噪声向量(维度为NOISE_DIM),这也是生成器的输入。gan_output = disc(gen(gan_input)):将噪声输入生成器得到伪造样本,再将伪造样本输入判别器,得到判别器的输出(即对伪造样本的“真假判断”),这是GAN的最终输出。构建并编译GAN模型:
gan = Model(gan_input, gan_output):以噪声为输入、判别器对伪造样本的判断为输出,构建GAN模型。- 编译时使用二元交叉熵损失(
binary_crossentropy)和Adam优化器(学习率0.0002,动量参数beta_1=0.5,这是GAN训练中常用的参数设置)。整体来看,这个函数的核心是将生成器和判别器串联,通过冻结判别器,让生成器在训练时专注于学习如何生成能“骗过”判别器的样本。
??????这里有个疑惑,为什么要冻结disc,——这里不包含参数的更新;在train()函数中,有对应的模块进行训练disc、gen,这一步的操作意义好像没有,所以,对这一步不是很理解
看下面训练的操作可知了
可视化
save_images()函数
- 10个epoch存一张图
matplotlib.use("Agg")保证在纯命令行服务器也能跑- 同时也存图
[!IMPORTANT]
这段代码定义了一个用于保存生成模型(如GAN中的生成器)生成图像的函数
save_images,主要功能解析如下:
参数说明:
gen:生成模型(生成器),用于根据输入噪声生成图像epoch:当前训练轮次,用于命名保存的图片examples:生成图像的数量(默认25张)dim:图像排列的行列数(默认5x5网格,对应25张图)figsize:画布大小(默认6x6)核心逻辑:
- 生成符合正态分布的随机噪声(均值0,标准差1),形状为
(examples, NOISE_DIM)(NOISE_DIM是噪声向量维度)- 用生成器
gen对噪声进行预测,得到生成图像,并将其reshape为(examples, 28, 28)(符合MNIST等手写数字数据集的图像尺寸)- 使用matplotlib创建画布,按
dim指定的网格排列图像- 每张图像以灰度图(
cmap="gray")显示,关闭坐标轴- 调整布局后,将图像保存到
RESULT_DIR目录,文件名格式为generated_epoch_xxx.png(xxx为3位数字的轮次)- 关闭画布,避免内存占用
整体作用是在模型训练过程中,定期保存生成器的输出结果,方便观察模型生成效果随训练的变化。
train()
整合、循环训练
1 | for epoch in 1~50: |
one判别器、one生成器,交替进行
train过程中训练方式:
- gen模型仅是用来生成fake_imgs
- disc模型用来判断,会进行train_on_batch,两次,对应于real和fake
- gan模型虽然是包含了上面两个模型,但在其开始会disc.trainable=False,,所以链式模型
gan的train_on_batch只更新生成器 G 的权重(链式模型 gan 天然只反向到 G,因为在build_gan时,通过disc.trainable = False实现不把D的权重放进gan的**trainable_weights**)——呼应build_gan的False操作
G 永远开放,D 动态上锁;
gan 只反到 G,权重列表作证。
[!IMPORTANT]
这段代码是一个典型的生成对抗网络(GAN)训练流程,主要逻辑解析如下:
数据准备:
- 加载训练数据
x_train,计算数据量和每个epoch的训练步数(总数据量//批次大小)模型构建:
- 初始化生成器(
gen)、判别器(disc)- 构建GAN整体模型(
gan),将生成器和判别器结合标签设置:
- 真实样本标签设为1(
real_labels)- 生成样本标签设为0(
fake_labels)训练循环:
- 外层循环按epoch迭代(从1到EPOCHS)
- 内层循环按步数迭代(每个epoch包含steps_per_epoch步)
判别器训练:
- 生成随机噪声,通过生成器得到假样本(
fake_imgs)- 从训练集中随机抽取真实样本(
real_imgs)- 开启判别器训练模式,分别用真实样本和假样本训练,得到两个损失(
d_loss_real和d_loss_fake)生成器训练:
- 冻结判别器参数(
disc.trainable = False)- 用噪声作为输入,以真实标签(1)为目标训练生成器,让判别器无法区分生成样本
训练监控与保存:
- 每个epoch结束打印损失值
- 第1个epoch和每10个epoch保存生成的图片
- 训练结束后保存生成器和判别器的权重
核心思想是通过生成器和判别器的对抗训练(生成器试图生成逼真样本欺骗判别器,判别器试图区分真假样本),最终使生成器能够生成高质量的逼真样本。
| 层级 | 决定的动作 | 直接影响的量(已有参数) | 典型代码位置 |
|---|---|---|---|
| step(微观) | 1. 计算梯度 2. 执行权重更新 |
全部模型权重 W, b |
optimizer.apply_gradients()model.train_on_batch() |
| epoch(宏观) | 1. 学习率调度 2. 早停 3. 保存权重 4. 打印/可视化 |
- 学习率 lr(调度器)- 最优权重副本 best_model.*- 日志、图像、检查点 |
ReduceLROnPlateauModelCheckpointsave_images()print(f"Epoch {e} loss={loss:.4f}") |
step 决定“权重什么时候改”;epoch 决定“学习率什么时候降、权重什么时候存、日志什么时候打”
运行
train()
1 | 真图 ──┐ |
生成两个.h5权重文件
和一些过程数字图
1 | """ |
gen、disc生成用的是Sequential
gan用的是Model
1 | import、设参 |
基于PyTorch
1 | import torch |
前提
模块导入
1 |
torch:PyTorch核心torchvision:听MNIST数据集和图像变换matplotlib:可视化
设置随机种子&设备
1 | torch.manual_seed(42) |
零件
生成器
1 | class Generator(nn.Module):#继承 |
DCGAN 设计准则:
生成器中,每个转置卷积后都跟 BatchNorm + ReLU(最后一层除外)。
归一化处理,再运行激活函数
PyTorch 内部会:
- 检查
generator是nn.Module的子类; - 自动调用
generator.forward(noise); - 同时记录计算图(computation graph),以便后续反向传播(
.backward())。
判别器
1 | # 判别器网络 |
预备
超参
1 | # 这些参数的选择对GAN的训练稳定性和生成效果有重要影响 |
处理
1 | # 数据加载和预处理 |
DataLoader(...)
- 作用:将数据集封装成可迭代的批次(batches),支持多线程、打乱等。
- 参数详解:
dataset:前面创建的MNIST对象。batch_size=batch_size(如 64):每次迭代返回 64 张图像 + 标签。shuffle=True:每个 epoch 打乱数据顺序,防止模型记住 batch 顺序,提升泛化能力。
.to(device)
将模型参数(weights)和缓冲区(如 BN 的 running_mean)移动到指定设备(CPU / CUDA / MPS)。
optim.Adam
generator.parameters():获取生成器所有可训练参数(weights + biases)lr=0.0002:学习率。GAN 通常用较小 lr(如 0.0002),避免训练震荡。betas=(beta1, 0.999)1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
:
- `beta1=0.5`(你在前面定义):**一阶动量衰减率**(默认 0.9,但 DCGAN 推荐 0.5)
- `0.999`:二阶动量衰减率(默认值)
## 训练
```python
def train():
for epoch in range(num_epochs):
for i, (real_images, _) in enumerate(dataloader):#数据加载器返回batch
batch_size = real_images.size(0)#获取氮气batch实际大小
real_images = real_images.to(device)#
# 训练判别器
# 判别器需要学会区分真实图像(标签为1)和生成图像(标签为0)
d_optimizer.zero_grad() # 清空判别器的梯度
label_real = torch.ones(batch_size, 1).to(device) # 真实图像的标签为1
label_fake = torch.zeros(batch_size, 1).to(device) # 生成图像的标签为0
# 计算判别器对真实图像的损失
output_real = discriminator(real_images) # 判别器对真实图像的预测结果
d_loss_real = criterion(output_real, label_real) # 计算真实图像的二元交叉熵损失
# 生成假图像并计算判别器的损失
noise = torch.randn(batch_size, latent_dim).to(device) # 生成随机噪声
fake_images = generator(noise) # 使用生成器生成假图像
output_fake = discriminator(fake_images.detach()) # detach()防止梯度传递到生成器
d_loss_fake = criterion(output_fake, label_fake) # 计算假图像的二元交叉熵损失
# 计算判别器的总损失并更新参数
d_loss = d_loss_real + d_loss_fake # 判别器总损失是真假图像损失之和
d_loss.backward() # 反向传播计算梯度
d_optimizer.step() # 更新判别器参数
# 训练生成器
# 生成器的目标是生成能够欺骗判别器的图像
g_optimizer.zero_grad() # 清空生成器的梯度
output_fake = discriminator(fake_images) # 判别器对生成图像的预测
# 生成器的损失:希望判别器将生成的图像判断为真实图像
g_loss = criterion(output_fake, label_real) # 使用真实标签计算损失
g_loss.backward() # 反向传播计算梯度
g_optimizer.step() # 更新生成器参数
if i % 100 == 0:
print(f'Epoch [{epoch}/{num_epochs}] Batch [{i}/{len(dataloader)}] '
f'd_loss: {d_loss.item():.4f} g_loss: {g_loss.item():.4f}')
# 每个epoch保存生成的图像样本
if (epoch + 1) % 10 == 0:
save_fake_images(epoch + 1)
#计算损失,.backward()反向传播,x_optimizer.step()更新参数
x_optimizer.zero_grad()
PyTorch 默认累加梯度(用于梯度裁剪等),但每次更新前必须清零,否则梯度会错误累积。
损失函数criterion,来计算差异,同时进行反向传播计算梯度,更新参数
save
1 | def save_fake_images(epoch): |
loss
一个标量张量
通过损失函数来生成loss(差距)
d_loss_real = criterion(output_real, label_real)——判别器输出接近1(真),loss近似0(label_real=1);反之,近似1
d_loss_fake = criterion(output_fake, label_fake)——判别器输出接近0(假),loss近似0(label_fake=0);反之,近似1
g_loss = criterion(output_fake, label_real)——同样,loss近似为0时,说明很逼真了
代码
1 | import torch |
代码参考链接使用PyTorch实现MNIST数据集的GAN网络_gan mnist-CSDN博客
(代码有变动)




