当前位置: 首页>編程日記>正文

对抗神经网络学习和实现(GAN)

对抗神经网络学习和实现(GAN)

一,GAN的原理介绍

\quadGAN的基本原理其实非常简单,这里以生成图片为例进行说明。假设我们有两个网络,G(Generator)和D(Discriminator)。正如它的名字所暗示的那样,它们的功能分别是:
∙\bulletG是一个生成图片的网络,它接收一个随机的噪声z,通过这个噪声生成图片,记做G(z)。
∙\bulletD是一个判别网络,判别一张图片是不是“真实的”。它的输入参数是x,x代表一张图片,输出D(x)代表x为真实图片的概率,如果为1,就代表100%是真实的图片,而输出为0,就代表不可能是真实的图片。
\quad在训练过程中**,生成网络G的目标就是尽量生成真实的图片去欺骗判别网络D。而D的目标就是尽量把G生成的图片和真实的图片分别开来**。这样,G和D构成了一个动态的“博弈过程”。
\quad最后博弈的结果是什么?在最理想的状态下,G可以生成足以“以假乱真”的图片G(z)。对于D来说,它难以判定G生成的图片究竟是不是真实的,因此D(G(z)) = 0.5。
\quad这样我们的目的就达成了:我们得到了一个生成式的模型G,它可以用来生成图片。
\quad以上只是大致说了一下GAN的核心原理,如何用数学语言描述呢?这里直接摘录论文里的公式:
在这里插入图片描述

简单分析一下这个公式:
∙\bullet整个式子由两项构成。x表示真实图片,z表示输入G网络的噪声,而G(z)表示G网络生成的图片。
∙\bulletD(x)表示D网络判断真实图片是否真实的概率(因为x就是真实的,所以对于D来说,这个值越接近1越好)。而D(G(z))是D网络判断G生成的图片的是否真实的概率
∙\bulletG的目的:上面提到过,D(G(z))是D网络判断G生成的图片是否真实的概率,G应该希望自己生成的图片“越接近真实越好”。也就是说,G希望D(G(z))尽可能得大,这时V(D, G)会变小。因此我们看到式子的最前面的记号是min_G
∙\bulletD的目的:D的能力越强,D(x)应该越大,D(G(x))应该越小。这时V(D,G)会变大。因此式子对于D来说是求最大(max_D)

下面这幅图片很好地描述了这个过程:
在这里插入图片描述
那么如何用随机梯度下降法训练D和G?论文中也给出了算法:
在这里插入图片描述
这里红框圈出的部分是我们要额外注意的。第一步我们训练D,D是希望V(G,`D)越大越好,所以是加上梯度(ascending)。第二步训练G时,V(G, D)越小越好,所以是减去梯度(descending)。整个训练过程交替进行。
\quad这个直接看代码就有很好的体现了。

二,DCGAN的原理介绍

\quad我们知道深度学习中对图像处理应用最好的模型是CNN,那么如何把CNN与GAN结合?DCGAN是这方面最好的尝试之一,文章最后会附上论文地址:
\quadDCGAN的原理和GAN是一样的,这里就不在赘述。它只是把上述的G和D换成了两个卷积神经网络(CNN)。但不是直接换就可以了,DCGAN对卷积神经网络的结构做了一些改变,以提高样本的质量和收敛的速度,这些改变有:
∙\bullet取消所有pooling层。G网络中使用转置卷积(transposed convolutional layer)进行上采样,D网络中用加入stride的卷积代替pooling。
∙\bullet在D和G中均使用batch normalization
∙\bullet去掉FC层,使网络变为全卷积网络
∙\bulletG网络中使用ReLU作为激活函数,最后一层使用tanh
∙\bulletD网络中使用LeakyReLU作为激活函数
\quadDCGAN中的G网络示意:
在这里插入图片描述
\quad直接利用1个全连接层实现的GAN如下,补充:一开始的时候生成出来的图像中含有负数以至于无法显示,解决方案是在generate的输出激活函数中不要使用tanh,改为使用sigmoid即可。

\quad训练的过程如下:
这里写图片描述
训练完成之后生成了20张随机产生的图片,是不是足够以假乱真啦。

#coding=utf-8
import tensorflow as tf
import tflearn
import tflearn.datasets.mnist as mnist
import matplotlib.pyplot as plt
import numpy as np
X, Y, X_test, Y_test = mnist.load_data()
img_dim = 784
z_dim = 200
total_sample = len(X)#构建生成器和判别器
def generate(x, reuse=tf.AUTO_REUSE):with tf.variable_scope('Generate', reuse=reuse):x = tflearn.fully_connected(x,256,activation='relu')x = tflearn.fully_connected(x,img_dim,activation='sigmoid')return x
def discriminator(x, reuse=tf.AUTO_REUSE):with tf.variable_scope('Discriminator', reuse=reuse):x = tflearn.fully_connected(x, 256, activation='relu')x = tflearn.fully_connected(x, 1, activation='sigmoid')return x#构建网络gen_input = tflearn.input_data(shape=[None,z_dim], name='input_noise')
disc_input = tflearn.input_data(shape=[None,784], name='disc_input')#生成器,判别器
gen_sample = generate(gen_input)
disc_real = discriminator(disc_input) #判别网络
disc_fake = discriminator(gen_sample) #欺骗网络Ddisc_loss = -tf.reduce_mean(tf.log(disc_real)+tf.log(1. -disc_fake))
gen_loss = -tf.reduce_mean(tf.log(disc_fake))gen_vars = tflearn.get_layer_variables_by_scope('Generate')
gen_model = tflearn.regression(gen_sample, placeholder=None, optimizer='adam', loss=gen_loss, trainable_vars=gen_vars,batch_size=64,name='target_gen',op_name='GEN')disc_vars = tflearn.get_layer_variables_by_scope('Discriminator')
disc_model = tflearn.regression(disc_real,placeholder=None,optimizer='adam',loss=disc_loss,trainable_vars=disc_vars,batch_size=64,name='target_disc',op_name='DISC')gan = tflearn.DNN(gen_model)#训练并绘制图像
z = np.random.uniform(-1.,1.,[total_sample,z_dim])
gan.fit(X_inputs={gen_input: z,disc_input: X},Y_targets=None,n_epoch=200)f, a = plt.subplots(2,10,figsize=(10,4))
for i in range(10):for j in range(2):#Noise inputz = np.random.uniform(-1.,1.,size=[1,z_dim])#Generate image from noise. Extend to 3 channels for matplot figure.temp = [[temp,temp,temp] for temp in list(gan.predict([z])[0])]print(temp)a[j][i].imshow(np.reshape(temp,(28,28,3)))
f.show()
plt.show()

代码2:https://github.com/wiseodd/generative-models

GAN的TF实现

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import osdef xavier_init(size):in_dim = size[0]xavier_stddev = 1. / tf.sqrt(in_dim / 2.)return tf.random_normal(shape=size, stddev=xavier_stddev)X = tf.placeholder(tf.float32, shape=[None, 784])D_W1 = tf.Variable(xavier_init([784, 128]))
D_b1 = tf.Variable(tf.zeros(shape=[128]))D_W2 = tf.Variable(xavier_init([128, 1]))
D_b2 = tf.Variable(tf.zeros(shape=[1]))theta_D = [D_W1, D_W2, D_b1, D_b2]Z = tf.placeholder(tf.float32, shape=[None, 100])G_W1 = tf.Variable(xavier_init([100, 128]))
G_b1 = tf.Variable(tf.zeros(shape=[128]))G_W2 = tf.Variable(xavier_init([128, 784]))
G_b2 = tf.Variable(tf.zeros(shape=[784]))theta_G = [G_W1, G_W2, G_b1, G_b2]def sample_Z(m, n):return np.random.uniform(-1., 1., size=[m, n])def generator(z):G_h1 = tf.nn.relu(tf.matmul(z, G_W1) + G_b1)G_log_prob = tf.matmul(G_h1, G_W2) + G_b2G_prob = tf.nn.sigmoid(G_log_prob)return G_probdef discriminator(x):D_h1 = tf.nn.relu(tf.matmul(x, D_W1) + D_b1)D_logit = tf.matmul(D_h1, D_W2) + D_b2D_prob = tf.nn.sigmoid(D_logit)return D_prob, D_logitdef plot(samples):fig = plt.figure(figsize=(4, 4))gs = gridspec.GridSpec(4, 4)gs.update(wspace=0.05, hspace=0.05)for i, sample in enumerate(samples):ax = plt.subplot(gs[i])plt.axis('off')ax.set_xticklabels([])ax.set_yticklabels([])ax.set_aspect('equal')plt.imshow(sample.reshape(28, 28), cmap='Greys_r')return figG_sample = generator(Z)
D_real, D_logit_real = discriminator(X)
D_fake, D_logit_fake = discriminator(G_sample)# D_loss = -tf.reduce_mean(tf.log(D_real) + tf.log(1. - D_fake))
# G_loss = -tf.reduce_mean(tf.log(D_fake))# Alternative losses:
# -------------------
D_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_logit_real, labels=tf.ones_like(D_logit_real)))
D_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_logit_fake, labels=tf.zeros_like(D_logit_fake)))
D_loss = D_loss_real + D_loss_fake
G_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_logit_fake, labels=tf.ones_like(D_logit_fake)))D_solver = tf.train.AdamOptimizer().minimize(D_loss, var_list=theta_D)
G_solver = tf.train.AdamOptimizer().minimize(G_loss, var_list=theta_G)mb_size = 128
Z_dim = 100mnist = input_data.read_data_sets('../../MNIST_data', one_hot=True)sess = tf.Session()
sess.run(tf.global_variables_initializer())if not os.path.exists('out/'):os.makedirs('out/')i = 0for it in range(1000000):if it % 1000 == 0:samples = sess.run(G_sample, feed_dict={Z: sample_Z(16, Z_dim)})fig = plot(samples)plt.savefig('out/{}.png'.format(str(i).zfill(3)), bbox_inches='tight')i += 1plt.close(fig)X_mb, _ = mnist.train.next_batch(mb_size)_, D_loss_curr = sess.run([D_solver, D_loss], feed_dict={X: X_mb, Z: sample_Z(mb_size, Z_dim)})_, G_loss_curr = sess.run([G_solver, G_loss], feed_dict={Z: sample_Z(mb_size, Z_dim)})if it % 1000 == 0:print('Iter: {}'.format(it))print('D loss: {:.4}'. format(D_loss_curr))print('G_loss: {:.4}'.format(G_loss_curr))
print()

DCGAN待续:

参考链接:
https://zhuanlan.zhihu.com/p/24767059
http://blog.csdn.net/twt520ly/article/details/79420597
https://zhuanlan.zhihu.com/p/27295635
https://wiseodd.github.io/techblog/2016/09/17/gan-tensorflow/

维护了一个微信公众号,分享论文,算法,比赛,生活,欢迎加入。

在这里插入图片描述


https://www.fengoutiyan.com/post/13757.html

相关文章:

  • 鏡像模式如何設置在哪,圖片鏡像操作
  • 什么軟件可以把圖片鏡像翻轉,C#圖片處理 解決左右鏡像相反(旋轉圖片)
  • 手機照片鏡像翻轉,C#圖像鏡像
  • 視頻鏡像翻轉軟件,python圖片鏡像翻轉_python中鏡像實現方法
  • 什么軟件可以把圖片鏡像翻轉,利用PS實現圖片的鏡像處理
  • 照片鏡像翻轉app,java實現圖片鏡像翻轉
  • 什么軟件可以把圖片鏡像翻轉,python圖片鏡像翻轉_python圖像處理之鏡像實現方法
  • matlab下載,matlab如何鏡像處理圖片,matlab實現圖像鏡像
  • 圖片鏡像翻轉,MATLAB:鏡像圖片
  • 鏡像翻轉圖片的軟件,圖像處理:實現圖片鏡像(基于python)
  • canvas可畫,JavaScript - canvas - 鏡像圖片
  • 圖片鏡像翻轉,UGUI優化:使用鏡像圖片
  • Codeforces,CodeForces 1253C
  • MySQL下載安裝,Mysql ERROR: 1253 解決方法
  • 勝利大逃亡英雄逃亡方案,HDU - 1253 勝利大逃亡 BFS
  • 大一c語言期末考試試題及答案匯總,電大計算機C語言1253,1253《C語言程序設計》電大期末精彩試題及其問題詳解
  • lu求解線性方程組,P1253 [yLOI2018] 扶蘇的問題 (線段樹)
  • c語言程序設計基礎題庫,1253號C語言程序設計試題,2016年1月試卷號1253C語言程序設計A.pdf
  • 信奧賽一本通官網,【信奧賽一本通】1253:抓住那頭牛(詳細代碼)
  • c語言程序設計1253,1253c語言程序設計a(2010年1月)
  • 勝利大逃亡英雄逃亡方案,BFS——1253 勝利大逃亡
  • 直流電壓測量模塊,IM1253B交直流電能計量模塊(艾銳達光電)
  • c語言程序設計第三版課后答案,【渝粵題庫】國家開放大學2021春1253C語言程序設計答案
  • 18轉換為二進制,1253. 將數字轉換為16進制
  • light-emitting diode,LightOJ-1253 Misere Nim
  • masterroyale魔改版,1253 Dungeon Master
  • codeformer官網中文版,codeforces.1253 B
  • c語言程序設計考研真題及答案,2020C語言程序設計1253,1253計算機科學與技術專業C語言程序設計A科目2020年09月國家開 放大學(中央廣播電視大學)
  • c語言程序設計基礎題庫,1253本科2016c語言程序設計試題,1253電大《C語言程序設計A》試題和答案200901
  • 肇事逃逸車輛無法聯系到車主怎么辦,1253尋找肇事司機