生成对抗网络

生成对抗网络(GAN)

参考文章

GAN(生成对抗网络)的系统全面介绍(醍醐灌顶)-CSDN博客

[图解 生成对抗网络GAN 原理 超详解_gan原理图-CSDN博客](https://blog.csdn.net/DFCED/article/details/105175097#:~:text=生成式对抗网络(GAN%2C Generative Adversarial Networks )是一种深度学习模型,是近年来复杂分布上无监督学习最具前景的方法之一。 模型通过框架中(至少)两个模块:生成模型(Generative Model)和判别模型(Discriminative Model)的互相博弈学习产生相当好的输出。,都是 神经网络,只需要是能拟合相应生成和判别的函数即可。 但实用中一般均使用深度神经网络作为 G 和 D 。 一个优秀的GAN应用需要有良好的训练方法,否则可能由于神经网络模型的自由性而导致输出不理想。)

GAN:通过对某一事物大量数据的学习,来学习总结出其在数学层面上的分布规律,构建出合理的映射函数,从而解决现实问题

GAN的介绍

基本概念

GAN的全称是Generative adversarial network,中文翻译过来就是生成对抗网络。

生成对抗网络其实是两个网络的组合:

  • 生成网络(Generator)负责生成模拟数据;
  • 判别网络(Discriminator)负责判断输入的数据是真实的还是生成的。

生成网络要不断优化自己生成的数据让判别网络判断不出来,判别网络也要优化自己让自己判断得更准确。二者关系形成对抗,因此叫对抗网络。

GAN的基本架构图

GAN网络架构概念图

image-20240927142056332

GAN网络内部架构图

image-20240927142235521

具体怎么训练的概念介绍

这是一个生成器和判别器博弈的过程。生成器生成假数据,然后将生成的假数据和真数据都输入判别器,判别器要判断出哪些是真的哪些是假的。判别器第一次判别出来的肯定有很大的误差,然后我们根据误差来优化判别器。现在判别器水平提高了,生成器生成的数据很难再骗过判别器了,所以我们得反过来优化生成器,之后生成器水平提高了,然后反过来继续训练判别器,判别器水平又提高了,再反过来训练生成器,就这样循环往复,直到达到纳什均衡。
image-20240927144851832

具体生成网络和对抗网络的优化是如何实现的? **神经网络的架构和损失函数(loss function)**。神经网络架构和损失函数的定义是能够实现优化(训练)的两个基本要素。

GAN 算法中的生成器

对于生成器,输入需要一个n维度向量,输出为图片像素大小的图片。因而首先我们需要得到输入的向量。

这里的生成器可以是任意可以输出图片的模型,比如最简单的全连接神经网络,又或者是反卷积网络等。

这里输入的向量我们将其视为携带输出的某些信息,比如说手写数字为数字几,手写的潦草程度等等。由于这里我们对于输出数字的具体信息不做要求,只要求其能够最大程度与真实手写数字相似(能骗过判别器)即可。所以我们使用随机生成的向量来作为输入即可,这里面的随机输入最好是满足常见分布比如均值分布,高斯分布等。

image-20240927145830508

GAN算法中的判别器

对于判别器不用多说,往往是常见的判别器,输入为图片,输出为图片的真伪标签

image-20240927145923932

同理,判别器与生成器一样,可以是任意的判别器模型,比如全连接网络,或者是包含卷积的网络等

GAN强大之处在于能自动学习原始真实样本集的数据分布,不管这个分布多么的复杂,只要训练的足够好就可以学出来。

传统的机器学习方法,一般会先定义一个模型,再让数据去学习。

  • 比如知道原始数据属于高斯分布,但不知道高斯分布的参数,这时定义高斯分布,然后利用数据去学习高斯分布的参数,得到最终的模型。
  • 再比如定义一个分类器(如SVM),然后强行让数据进行各种高维映射,最后变成一个简单的分布SVM可以很轻易的进行二分类(虽然SVM放松了这种映射关系,但也给了一个模型,即核映射),其实也是事先知道让数据该如何映射,只是映射的参数可以学习。

以上这些方法都在直接或间接的告诉数据该如何映射,只是不同的映射方法能力不一样。

而GAN的生成模型最后可以通过噪声生成一个完整的真实数据(比如人脸),说明生成模型掌握了从随机噪声到人脸数据的分布规律。GAN一开始并不知道这个规律是什么样,也就是说GAN是通过一次次训练后学习到的真实样本集的数据分布。

image-20240927150657773

损失函数(loss function)

生成网络的损失函数

image-20240927151136682

上式中,G 代表生成网络,D 代表判别网络,H 代表交叉熵,z 是输入随机数据。 是对生成数据的判断概率,1代表数据绝对真实,0代表数据绝对虚假。 代表判断结果与1的距离。显然生成网络想取得良好的效果,那就要做到,让判别器将生成数据判别为真数据(即D(G(z))与1的距离越小越好)