深度学习之 GAN 进行 mnist 图片的生成
深度学习之 GAN 进行 mnist 图片的生成
mport numpy as np
import os
import codecs
import torch
from PIL import Image
import PIL
def get_int(b):
return int(codecs.encode(b, 'hex'), 16)
def extract_image(path, extract_path):
with open(path, 'rb') as f:
data = f.read()
assert get_int(data[:4]) == 2051
length = get_int(data[4:8])
num_rows = get_int(data[8:12])
num_cols = get_int(data[12:16])
images = []
parsed = np.frombuffer(data, dtype=np.uint8, offset=16)
parsed = parsed.reshape(length, num_rows, num_cols)
for image_i, image in enumerate(parsed):
Image.fromarray(image, 'L').save(os.path.join(extract_path, 'image_{}.jpg'.format(image_i)))
image_path = './mnist/t10k-images.idx3-ubyte'
extract_path = './mnist/data/image'
import math
def images_square_grid(images, mode):
save_size = math.floor(np.sqrt(images.shape[0]))
# Scale to 0-255
images = (((images - images.min()) * 255) / (images.max() - images.min())).astype(np.uint8)
# Put images in a square arrangement
images_in_square = np.reshape(
images[:save_size*save_size],
(save_size, save_size, images.shape[1], images.shape[2], images.shape[3]))
if mode == 'L':
images_in_square = np.squeeze(images_in_square, 4)
# Combine images to grid image
new_im = Image.new(mode, (images.shape[1] * save_size, images.shape[2] * save_size))
for col_i, col_images in enumerate(images_in_square):
for image_i, image in enumerate(col_images):
im = Image.fromarray(image, mode)
new_im.paste(im, (col_i * images.shape[1], image_i * images.shape[2]))
return new_im
def get_image(image_path, width, height, mode):
image = Image.open(image_path)
if image.size != (width, height):
face_width = face_width = 108
j = (image.size[0] - face_width) // 2
i = (image.size[1] - face_height) // 2
image = image.crop([j, i, j + face_width, i + face_height])
image = image.resize([width, height], Image.BILINEAR)
return np.array(image.convert(mode))
def get_batch(image_files, width, height, mode):
data_batch = np.array([get_image(sample_file, width, height, mode) for sample_file in image_files]).astype(np.float32)
if len(data_batch.shape) < 4:
data_batch = data_batch.reshape(data_batch.shape + (1,))
return data_batch
%matplotlib inline
import os
from glob import glob
from matplotlib import pyplot
data_dir = './mnist/data'
show_n_images = 25
mnist_images = get_batch(glob(os.path.join(data_dir, 'image/*.jpg'))[:show_n_images], 28, 28, 'L')
pyplot.imshow(images_square_grid(mnist_images, 'L'), cmap='gray')
from torch.utils import data
import torchvision as tv
batch_size = 50
transforms = tv.transforms.Compose([
tv.transforms.Resize(96),
PIL.ImageOps.grayscale,
tv.transforms.ToTensor()
])
root="d:\\work\\yoho\\dl\\dl-study\\chapter8\\mnist\\data"
dataset = tv.datasets.ImageFolder(root, transform=transforms)
dataloader = data.DataLoader(dataset, batch_size, shuffle=True, num_workers=1, drop_last=True)
import torch.nn as nn
import torch.optim as optim
from torch.nn.modules import loss
from torch.autograd import Variable as V
class GNet(nn.Module):
def __init__(self, opt):
super(GNet, self).__init__()
ngf = opt["ngf"]
target = opt["target"] or 3
self.main = nn.Sequential(
nn.ConvTranspose2d( opt["nz"], ngf * 8, 4, 1, 0, bias=False),
nn.BatchNorm2d(ngf * 8),
nn.ReLU(True),
nn.ConvTranspose2d( ngf * 8, ngf * 4, 4, 2, 1, bias=False),
nn.BatchNorm2d(ngf * 4),
nn.ReLU(True),
nn.ConvTranspose2d( ngf * 4, ngf * 2, 4, 2, 1, bias=False),
nn.BatchNorm2d(ngf * 2),
nn.ReLU(True),
nn.ConvTranspose2d( ngf * 2, ngf, 4, 2, 1, bias=False),
nn.BatchNorm2d(ngf),
nn.ReLU(True),
nn.ConvTranspose2d( ngf, target, 5, 3, 1, bias=False),
nn.Tanh()
)
def forward(self, input):
return self.main(input)
class DNet(nn.Module):
def __init__(self, opt):
super(DNet, self).__init__()
ndf = opt["ndf"]
input = opt["input"] or 3
self.main = nn.Sequential(
nn.Conv2d(input, ndf, 5, 3, 1, bias=False),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
nn.BatchNorm2d(ndf * 2),
nn.LeakyReLU(0.3, inplace=True),
nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
nn.BatchNorm2d(ndf * 4),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
nn.BatchNorm2d(ndf * 8),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
nn.Sigmoid()
)
def forward(self, input):
return self.main(input).view(-1)
lr_g = 0.01
lr_d = 0.01
ngf = 64
ndf = 64
raw_f = 1
nz = 100
d_every = 1
g_every = 5
net_g = GNet({"target": raw_f, "ngf": ngf, 'nz': nz})
net_d = DNet({"input": raw_f, "ndf": ndf})
opt_g = optim.Adam(net_g.parameters(), lr_g, betas=(0.5, 0.999))
opt_d = optim.Adam(net_d.parameters(), lr_g, betas=(0.5, 0.999))
criterion = torch.nn.BCELoss()
true_labels = V(torch.ones(batch_size))
fake_labels = V(torch.zeros(batch_size))
fix_noises = V(torch.randn(batch_size, nz, 1, 1))
noises = V(torch.randn(batch_size, nz, 1, 1))
def train():
for ii, (img, _) in enumerate(dataloader):
real_img = V(img)
if (ii + 1) % d_every == 0:
opt_d.zero_grad()
output = net_d(real_img)
loss_d = criterion(output, true_labels)
loss_d.backward()
noises.data.copy_(torch.randn(batch_size, nz, 1, 1))
fake_img = net_g(noises)
fake_img = fake_img.detach()
fake_output = net_d(fake_img)
loss_fake_d = criterion(fake_output, fake_labels)
loss_fake_d.backward()
opt_d.step()
if (ii + 1) % g_every == 0:
opt_g.zero_grad()
noises.data.copy_(torch.randn(batch_size, nz, 1, 1))
fake_image = net_g(noises)
fake_output = net_d(fake_img)
loss_g = criterion(fake_output, true_labels)
loss_g.backward()
opt_g.step()
def print_image():
fix_fake_imgs = net_g(fix_noises)
fix_fake_imgs = fix_fake_imgs.data.view(batch_size, 96, 96, 1).numpy()
pyplot.imshow(images_square_grid(fix_fake_imgs, 'L'), cmap='gray')
epochs = 20
def main():
for i in range(epochs):
print("epoch {}".format(i))
train()
if i % 2 == 0:
print_image()
main()
注意 GAN 很慢,要使用 GPU来工作
深度学习之 GAN 进行 mnist 图片的生成的更多相关文章
- 4.keras实现-->生成式深度学习之用变分自编码器VAE生成图像(mnist数据集和名人头像数据集)
变分自编码器(VAE,variatinal autoencoder) VS 生成式对抗网络(GAN,generative adversarial network) 两者不仅适用于图像,还可以 ...
- 【深度学习】--GAN从入门到初始
一.前述 GAN,生成对抗网络,在2016年基本火爆深度学习,所有有必要学习一下.生成对抗网络直观的应用可以帮我们生成数据,图片. 二.具体 1.生活案例 比如假设真钱 r 坏人定义为G 我们通过 ...
- 深度学习-Wasserstein GAN论文理解笔记
GAN存在问题 训练困难,G和D多次尝试没有稳定性,Loss无法知道能否优化,生成样本单一,改进方案靠暴力尝试 WGAN GAN的Loss函数选择不合适,使模型容易面临梯度消失,梯度不稳定,优化目标不 ...
- 机器学习 —— 深度学习 —— 基于DAGNN的MNIST NET
DAGNN 是Directed acyclic graph neural network 缩写,也就有向图非循环神经网络.我使用的是由MatConvNet 提供的DAGNN API.选择这套API作为 ...
- 深度学习之GAN对抗神经网络
1.结构图 2.知识点 生成器(G):将噪音数据生成一个想要的数据 判别器(D):将生成器的结果进行判别, 3.代码及案例 # coding: utf-8 # ## 对抗生成网络案例 ## # # # ...
- 【Python开发】【神经网络与深度学习】网络爬虫之图片自动下载器
python爬虫实战--图片自动下载器 之前介绍了那么多基本知识[Python爬虫]入门知识(没看的赶紧去看)大家也估计手痒了.想要实际做个小东西来看看,毕竟: talk is cheap show ...
- 为什么要用深度学习来做个性化推荐 CTR 预估
欢迎大家前往腾讯云技术社区,获取更多腾讯海量技术实践干货哦~ 作者:苏博览 深度学习应该这一两年计算机圈子里最热的一个词了.基于深度学习,工程师们在图像,语音,NLP等领域都取得了令人振奋的进展.而深 ...
- NLP+VS︱深度学习数据集标注工具、方法摘录,欢迎补充~~
~~因为不太会使用opencv.matlab工具,所以在找一些比较简单的工具. . . 一.NLP标注工具BRAT BRAT是一个基于web的文本标注工具,主要用于对文本的结构化标注,用BRAT生成的 ...
- 好书推荐计划:Keras之父作品《Python 深度学习》
大家好,我禅师的助理兼人工智能排版住手助手条子.可能非常多人都不知道我.由于我真的难得露面一次,天天给禅师做底层工作. wx_fmt=jpeg" alt="640? wx_fmt= ...
随机推荐
- php 目录处理函数
之前我们处理的全都是文件,那目录和文件夹怎么处理呢? 我们就来学习目录或者称为文件夹的处理相关函数. 处理文件夹的基本思想如下: 1.读取某个路径的时候判断是否是文件夹 2.是文件夹的话,打开指定文件 ...
- WPF-悬浮窗(类似于360)
boss要求开发一个类似于360的悬浮窗,如下图所示: 目前采用的是wpf做的客户端,之前有个winform的项目,我参考了下,完成了wpf版的悬浮窗. Height=" WindowSta ...
- java抽象类注意问题
当知道一个类的子类将不同的实现某个方法时,把该类声明为抽象类很有用,可以共用相同的父类方法,不必再定义. 抽象类和抽象方法的关系:含有抽象方法的类一定是抽象类,抽象类里不一定含有抽象方法. 抽象类存在 ...
- handsontable 事件汇总
Hook插件 afterChange (changes: Array, source: String):1个或多个单元格的值被改变后调用 changes:是一个2维数组包含row,prop,oldVa ...
- JavaScript中的execCommand
execCommand方法是执行一个对当前文档,当前选择或者给出范围的命令.处理Html数据时常用 如下格式:document.execCommand(sCommand[,交互方式, 动态参数]) , ...
- IM-iOS退出后台接受消息,app退出后台能接收到推送
App被失活状态的时候可以走苹果的APNS:但是在活跃的时候却接受不到推送! 那就用到本地推送:UILocalNotification 消息神器. 处理不好可能会有很多本地推送到来,那么问题来了要在什 ...
- 元素化设计原理及规则v1.0
一.元素设计架构 元素设计架构展示在基于元素化设计的思想下,系统各元素之间如何相互协作,并完成整个系统搭建. 架构中以Entity(数据)为中心,由Entity产生数据库表结构,并且Entity作为业 ...
- Readiness 探测 - 每天5分钟玩转 Docker 容器技术(144)
除了 Liveness 探测,Kubernetes Health Check 机制还包括 Readiness 探测. 用户通过 Liveness 探测可以告诉 Kubernetes 什么时候通过重启容 ...
- 关于css选择器中有小数点的标签获取
需求说明 因为项目中章节配置的时候有小数点,1,1.1,1.2,1.11的标题,这个时候每一行标题的id,class设置成标题号是独一无二的标记.但是,直接用js获取是获取不到的,例如$('#3.22 ...
- day1-计算机基础
第一单元 计算机组成原理 一.概念及过程 1.进行逻辑和数值高速计算的计算机器,有存储功能,能按照程序自动执行,且能够处理海量数据的现代化电子设备. 2.发展过程 数学运算:算盘,帕斯卡的齿轮装置, ...