玖叶教程网

前端编程开发入门

人工智能基础之-MATLAB生成对抗网络系列(持续更新)

因为之前用生成对抗网络及众多变体生成诸如心电信号,肌电信号,脑电信号,微震信号,机械振动信号,雷达信号等,但生成的信号在频谱或者时频谱上表现很差,所以暂时先不涉及到这些复杂信号,仅仅以手写数字图像为例进行说明,因为Python相关的资源太多了,我就不凑热闹了,使用的编程环境为MALAB R2021B。

首先看一下对抗自编码器AAE(Adversarial AutoEncoder),关于AAE的大致理解,可以查看如下文章

AAE(Adversarial Autoencoders)浅解 - 嘎嘎小鱼仔的文章 - 知乎 https://zhuanlan.zhihu.com/p/382958740

AAE根据变分自编码器VAE发展而来,其发展之处就在于加入了对抗的思想

上半部分就是一个简单典型的自编码器AE结构,包含输入层input layer,编码层encoder layer, 隐层hidden layer, 解码层decoder layer , 输出层output layer。encoder把真实分布x映射为隐层z, decoder 再将z解码还原成x。AAE的特点就在于在隐层hidden layer中引入了对抗的思想来优化隐层的z,判别器discriminator 需要在隐层判断采样后的真实数据和生成器encoder所产生的假数据。因此discriminator的目的就是使得q(z | x) 不断向p(z)靠近。

Adversarial Autoencoders论文链接:https://arxiv.org/abs/1511.0564

下面直接上代码

首先,导入相关的mnist手写数字图

load('mnistAll.mat')

然后对训练、测试图像进行预处理

trainX = preprocess(mnist.train_images); trainY = mnist.train_labels;%训练标签testX = preprocess(mnist.test_images); testY = mnist.test_labels;%测试标签

preprocess为归一化函数,如下

function x = preprocess(x)
x = double(x)/255;
x = (x-.5)/.5;
x = reshape(x,28*28,[]);
end

然后进行参数设置,包括潜变量空间维度,batch_size大小,学习率,最大迭代次数等等

settings.latent_dim = 10;
settings.batch_size = 32; settings.image_size = [28,28,1]; 
settings.lrD = 0.0002; settings.lrG = 0.0002; settings.beta1 = 0.5;
settings.beta2 = 0.999; settings.maxepochs = 50;

下面进行编码器初始化,代码还是很容易看懂的

paramsEn.FCW1 = dlarray(initializeGaussian([512,...
     prod(settings.image_size)],.02));
paramsEn.FCb1 = dlarray(zeros(512,1,'single'));
paramsEn.FCW2 = dlarray(initializeGaussian([512,512]));
paramsEn.FCb2 = dlarray(zeros(512,1,'single'));
paramsEn.FCW3 = dlarray(initializeGaussian([2*settings.latent_dim,512]));
paramsEn.FCb3 = dlarray(zeros(2*settings.latent_dim,1,'single'));

解码器初始化

paramsDe.FCW1 = dlarray(initializeGaussian([512,settings.latent_dim],.02));
paramsDe.FCb1 = dlarray(zeros(512,1,'single'));
paramsDe.FCW2 = dlarray(initializeGaussian([512,512]));
paramsDe.FCb2 = dlarray(zeros(512,1,'single'));
paramsDe.FCW3 = dlarray(initializeGaussian([prod(settings.image_size),512]));
paramsDe.FCb3 = dlarray(zeros(prod(settings.image_size),1,'single'));

判别器初始化

paramsDis.FCW1 = dlarray(initializeGaussian([512,settings.latent_dim],.02));
paramsDis.FCb1 = dlarray(zeros(512,1,'single'));
paramsDis.FCW2 = dlarray(initializeGaussian([256,512]));
paramsDis.FCb2 = dlarray(zeros(256,1,'single'));
paramsDis.FCW3 = dlarray(initializeGaussian([1,256]));
paramsDis.FCb3 = dlarray(zeros(1,1,'single'));

%平均梯度和平均梯度平方数组
avgG.Dis = []; avgGS.Dis = []; avgG.En = []; avgGS.En = [];
avgG.De = []; avgGS.De = [];

开始训练

dlx = gpdl(trainX(:,1),'CB');
dly = Encoder(dlx,paramsEn);
numIterations = floor(size(trainX,2)/settings.batch_size);
out = false; epoch = 0; global_iter = 0;
while ~out
    tic; 
    shuffleid = randperm(size(trainX,2));
    trainXshuffle = trainX(:,shuffleid);
    fprintf('Epoch %d\n',epoch) 
    for i=1:numIterations
        global_iter = global_iter+1;
        idx = (i-1)*settings.batch_size+1:i*settings.batch_size;
        XBatch=gpdl(single(trainXshuffle(:,idx)),'CB');

        [GradEn,GradDe,GradDis] = ...
                dlfeval(@modelGradients,XBatch,...
                paramsEn,paramsDe,paramsDis,settings);

        % 更新判别器网络参数
        [paramsDis,avgG.Dis,avgGS.Dis] = ...
            adamupdate(paramsDis, GradDis, ...
            avgG.Dis, avgGS.Dis, global_iter, ...
            settings.lrD, settings.beta1, settings.beta2);

        % 更新编码器网络参数
        [paramsEn,avgG.En,avgGS.En] = ...
            adamupdate(paramsEn, GradEn, ...
            avgG.En, avgGS.En, global_iter, ...
            settings.lrG, settings.beta1, settings.beta2);
        
        % 更新解码器网络参数
        [paramsDe,avgG.De,avgGS.De] = ...
            adamupdate(paramsDe, GradDe, ...
            avgG.De, avgGS.De, global_iter, ...
            settings.lrG, settings.beta1, settings.beta2);
        
        if i==1 || rem(i,20)==0
            progressplot(paramsDe,settings);
            if i==1 
                h = gcf;
                % 捕获图像
                frame = getframe(h); 
                im = frame2im(frame); 
                [imind,cm] = rgb2ind(im,256); 
                % 写入 GIF 文件
                if epoch == 0
                  imwrite(imind,cm,'AAEmnist.gif','gif', 'Loopcount',inf); 
                else 
                  imwrite(imind,cm,'AAEmnist.gif','gif','WriteMode','append'); 
                end 
            end
        end
        
    end

    elapsedTime = toc;
    disp("Epoch "+epoch+". Time taken for epoch = "+elapsedTime + "s")
    epoch = epoch+1;
    if epoch == settings.maxepochs
        out = true;
    end    
end

下面是完整的辅助函数

模型的梯度计算函数

function [GradEn,GradDe,GradDis]=modelGradients(x,paramsEn,paramsDe,paramsDis,settings)
dly = Encoder(x,paramsEn);
latent_fake = dly(1:settings.latent_dim,:)+...
    dly(settings.latent_dim+1:2*settings.latent_dim)*...
    randn(settings.latent_dim,settings.batch_size);
latent_real = gpdl(randn(settings.latent_dim,settings.batch_size),'CB');

%训练判别器
d_output_fake = Discriminator(latent_fake,paramsDis);
d_output_real = Discriminator(latent_real,paramsDis);
d_loss = -.5*mean(log(d_output_real+eps)+log(1-d_output_fake+eps));

%训练编码器和解码器
x_ = Decoder(latent_fake,paramsDe);
g_loss = .999*mean(mean(.5*(x_-x).^2,1))-.001*mean(log(d_output_fake+eps));

%对于每个网络,计算关于损失函数的梯度
[GradEn,GradDe] = dlgradient(g_loss,paramsEn,paramsDe,'RetainData',true);
GradDis = dlgradient(d_loss,paramsDis);
end

提取数据函数

function x = gatext(x)
x = gather(extractdata(x));
end

GPU深度学习数组wrapper函数

function dlx = gpdl(x,labels)
dlx = gpuArray(dlarray(x,labels));
end

权重初始化函数

function parameter = initializeGaussian(parameterSize,sigma)
if nargin < 2
    sigma = 0.05;
end
parameter = randn(parameterSize, 'single') .* sigma;
end

dropout函数

function dly = dropout(dlx,p)
if nargin < 2
    p = .3;
end
[n,d] = rat(p);
mask = randi([1,d],size(dlx));
mask(mask<=n)=0;
mask(mask>n)=1;
dly = dlx.*mask;
end

编码器函数

function dly = Encoder(dlx,params)
dly = fullyconnect(dlx,params.FCW1,params.FCb1);
dly = leakyrelu(dly,.2);
dly = fullyconnect(dly,params.FCW2,params.FCb2);
dly = leakyrelu(dly,.2);
dly = fullyconnect(dly,params.FCW3,params.FCb3);
dly = leakyrelu(dly,.2);
end

解码器函数

function dly = Decoder(dlx,params)
dly = fullyconnect(dlx,params.FCW1,params.FCb1);
dly = leakyrelu(dly,.2);
dly = fullyconnect(dly,params.FCW2,params.FCb2);
dly = leakyrelu(dly,.2);
dly = fullyconnect(dly,params.FCW3,params.FCb3);
dly = leakyrelu(dly,.2);
dly = tanh(dly);
end

判别器函数

function dly = Discriminator(dlx,params)
dly = fullyconnect(dlx,params.FCW1,params.FCb1);
dly = leakyrelu(dly,.2);
dly = fullyconnect(dly,params.FCW2,params.FCb2);
dly = leakyrelu(dly,.2);
dly = fullyconnect(dly,params.FCW3,params.FCb3);
dly = sigmoid(dly);
end

动态进度图

function progressplot(paramsDe,settings)
r = 5; c = 5;
noise = gpdl(randn([settings.latent_dim,r*c]),'CB');
gen_imgs = Decoder(noise,paramsDe);
gen_imgs = reshape(gen_imgs,28,28,[]);

fig = gcf;
if ~isempty(fig.Children)
    delete(fig.Children)
end

I = imtile(gatext(gen_imgs));
I = rescale(I);
imagesc(I)
title("Generated Images")
colormap gray

drawnow;
end

最后,看一下生成的GIF动态图

?以后会讲

(1)辅助分类器生成对抗网络Auxiliary Classifier Generative Adversarial Network

?(2)条件生成对抗网络Conditional Generative Adversarial Network

?(3)深层卷积生成对抗网络Deep Convolutional Generative Adversarial Network

?(4)最基础的生成对抗网络Basic Generative Adversarial Network

?(5)Info Generative Adversarial Network

?(6)最小二乘生成对抗网络Least Squares Generative Adversarial Network

?(7)著名的Pixels-to-Pixels

?(8)半监督生成对抗网络Semi-Supervised Generative Adversarial Network

?(9)著名的Wasserstein Generative Adversarial Network

?

?

相应的参考文献如下

  • Y. LeCun and C. Cortes, “MNIST handwritten digitdatabase,” 2010. [MNIST]
  • J. Deng, W. Dong, R. Socher, L.-J. Li, K. Li, andL. Fei-Fei, “ImageNet: A Large-Scale Hierarchical Image Database,” inCVPR09, 2009. [Apple2Orange (ImageNet)]
  • R. Tyle?ek and R. ?ára, “Spatial pattern templates forrecognition of objects with regular structure,” inProc.GCPR, (Saarbrucken, Germany), 2013. [Facade]
  • Z. Liu, P. Luo, X. Wang, and X. Tang, “Deep learn-ing face attributes in the wild,” inProceedings of In-ternational Conference on Computer Vision (ICCV),December 2015. [CelebA]
  • Goodfellow, Ian J. et al. “Generative Adversarial Networks.” ArXiv abs/1406.2661 (2014): n. pag. (GAN)
  • Radford, Alec et al. “Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks.” CoRR abs/1511.06434 (2015): n. pag. (DCGAN)
  • Denton, Emily L. et al. “Semi-Supervised Learning with Context-Conditional Generative Adversarial Networks.” ArXiv abs/1611.06430 (2017): n. pag. (CGAN)
  • Odena, Augustus et al. “Conditional Image Synthesis with Auxiliary Classifier GANs.” ICML (2016). (ACGAN)
  • Chen, Xi et al. “InfoGAN: Interpretable Representation Learning by Information Maximizing Generative Adversarial Nets.” NIPS (2016). (InfoGAN)
  • Makhzani, Alireza et al. “Adversarial Autoencoders.” ArXiv abs/1511.05644 (2015): n. pag. (AAE)
  • Isola, Phillip et al. “Image-to-Image Translation with Conditional Adversarial Networks.” 2017 IEEE Conference on Computer Vision and Pattern Recognition (CVPR) (2016): 5967-5976. (Pix2Pix)
  • J.-Y. Zhu, T. Park, P. Isola, and A. A. Efros, “Unpairedimage-to-image translation using cycle-consistent ad-versarial networks,” 2017. (CycleGAN)
  • Arjovsky, Martín et al. “Wasserstein GAN.” ArXiv abs/1701.07875 (2017): n. pag. (WGAN)
  • Odena, Augustus. “Semi-Supervised Learning with Generative Adversarial Networks.” ArXiv abs/1606.01583 (2016): n. pag. (SGAN)

所有的代码链接如下

https://mianbaoduo.com/o/bread/Y5eVkpZq

第三方面包多直接下载

发表评论:

控制面板
您好,欢迎到访网站!
  查看权限
网站分类
最新留言