Diffusers实战
Smiling & Weeping
---- 一生拥有自由和爱,是我全部的野心
1. 环境准备
%pip install diffusers
from huggingface_hub import notebook_login # 登录huggingface
notebook_login()
import numpy as np
import torch
import torch.nn.functional as F
from matplotlib import pyplot as plt
import torchvision
from PIL import Image
def show_images(x):
"""给定一批图像,创建一个网格并将其转换成PIL"""
x = x*0.5 + 0.5
grid = torchvision.utils.make_grid(x)
grid_im = grid.detach().cpu().permute(1, 2, 0).clip(0, 1)*255
grad_im = Image.fromarray(np.array(grid_im).astype(np.uint8))
return grad_im
def make_grid(images, size=64):
"""给定一个PIL图像列表,将他们叠加成一行以便查看"""
output_im = Image.new("RGB", (size*len(images), size))
for i, im in enumerate(images):
out_im.paste(im.resize((size, size)), (i*size, 0))
return output_im
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device
from diffusers import DDPMPipeline, StableDiffusionPipeline model_id = "sd-dreambooth-library/mr-potato-head"
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to(device)
prompt = "a cute anime characters using 8K resolution"
image = pipe(prompt, num_inference_steps=50, guidance_scale=5.5).images[0]
image
Diffusers核心API:
- 管线:从高层次设计的多种类函数,便于部署的方式实现,能够快速利用预训练的主流扩散模型来生成样本。
- 模型:在训练新的扩散模型时需要用到的网络结构。
- 调度器:在推理过程中使用多种不同的技巧来从噪声中生成图像,同时可以生成训练过程中所需的“带噪”图像。
import torchvision
from datasets import load_dataset
from torchvision import transforms
from diffusers import DDPMScheduler
from diffusers import DDPMPipeline, StableDiffusionPipeline dataset = load_dataset('lowres/anime', split="train") image_size = 256
batch_size = 8 preprocess = transforms.Compose([
transforms.Resize((image_size, image_size)),
transforms.ToTensor(),
transforms.RandomHorizontalFlip(),
transforms.Normalize([0.5], [0.5]),
]) def transform(examples):
images = [preprocess(image.convert("RGB")) for image in examples["image"]]
return {"images": images} dataset.set_transform(transform)
train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
xb = next(iter(train_dataloader))['images'].to(device)[:8]
print("X shape:", xb.shape)
show_images(xb).resize((8*256, 256), resample=Image.NEAREST)
# 定义调度器
from diffusers import DDPMScheduler noise_scheduler = DDPMScheduler(num_train_timesteps=1000, beta_start=0.001, beta_end=0.004)
timesteps = torch.linspace(0, 999, 8).long().to(device)
noise = torch.rand_like(xb)
noisy_xb = noise_scheduler.add_noise(xb, noise, timesteps)
print("Noise X Shape", noisy_xb.shape)
show_images(noisy_xb).resize((8*64, 64), resample=Image.NEAREST)
from diffusers import UNet2DModel model = UNet2DModel(
sample_size=image_size, # 目标图像的分辨率
in_channels=3,
out_channels=3,
layers_per_block=2, # 每一个UNet块中的ResNet层数
block_out_channels=(64, 128, 128, 256),
down_block_types=(
"DownBlock2D",
"DownBlock2D",
"AttnDownBlock2D", # 带有空域维度的self-att的ResNet下采样模块
"AttnDownBlock2D",
),
up_block_types=(
"AttnUpBlock2D",
"AttnUpBlock2D", # 带有空域维度的self-att的ResNet上采样模块
"UpBlock2D",
"UpBlock2D",
),
) model = model.to(device)
with torch.no_grad():
model_pred = model(noisy_xb, timesteps).sample model_pred.shape
训练
# 设定噪声调度器
noise_scheduler = DDPMScheduler(num_train_timesteps=1000, beta_schedule="squaredcos_cap_v2") # 训练循环
optimizer = torch.optim.Adam(model.parameters(), lr=4e-4) losses = [] # 定义损失函数
loss_fn = torch.nn.MSELoss() for epoch in range(45):
for step, batch in enumerate(train_dataloader):
# 未添加噪声的数据(clean data)
clean_data = batch['images'].to(device) # 生成噪声
noise = torch.randn(clean_data.shape).to(device)
bs = clean_data.shape[0] # 为每张图片随机采样一个时间步
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bs, ), device=device).long() # 噪声数据
# 根据每个时间步的噪声幅度(迭代次数),向清晰的图片中添加噪声
noisy_data = noise_scheduler.add_noise(clean_data, noise, timesteps) # 获得预测模型
pred_data = model(noisy_data, timesteps, return_dict=False)[0] # 计算损失
loss = loss_fn(pred_data, clean_data)
loss.backward()
losses.append(loss.item()) # 迭代模型参数
optimizer.step()
optimizer.zero_grad() if (epoch+1) % 5 == 0:
loss_last_epoch = sum(losses[-len(train_dataloader):]) / len(train_dataloader)
print(f"Epoch: {epoch+1}, loss: {loss_last_epoch}")
torch.save(model.state_dict(), 'save.pt')
绘制损失图线
fig, axs = plt.subplots(1, 2, figsize=(12, 4))
axs[0].plot(losses)
axs[1].plot(np.log(losses))
plt.show()
Diffusers实战的更多相关文章
- SSH实战 · 唯唯乐购项目(上)
前台需求分析 一:用户模块 注册 前台JS校验 使用AJAX完成对用户名(邮箱)的异步校验 后台Struts2校验 验证码 发送激活邮件 将用户信息存入到数据库 激活 点击激活邮件中的链接完成激活 根 ...
- GitHub实战系列汇总篇
基础: 1.GitHub实战系列~1.环境部署+创建第一个文件 2015-12-9 http://www.cnblogs.com/dunitian/p/5034624.html 2.GitHub实战系 ...
- MySQL 系列(四)主从复制、备份恢复方案生产环境实战
第一篇:MySQL 系列(一) 生产标准线上环境安装配置案例及棘手问题解决 第二篇:MySQL 系列(二) 你不知道的数据库操作 第三篇:MySQL 系列(三)你不知道的 视图.触发器.存储过程.函数 ...
- Asp.Net Core 项目实战之权限管理系统(4) 依赖注入、仓储、服务的多项目分层实现
0 Asp.Net Core 项目实战之权限管理系统(0) 无中生有 1 Asp.Net Core 项目实战之权限管理系统(1) 使用AdminLTE搭建前端 2 Asp.Net Core 项目实战之 ...
- 给缺少Python项目实战经验的人
我们在学习过程中最容易犯的一个错误就是:看的多动手的少,特别是对于一些项目的开发学习就更少了! 没有一个完整的项目开发过程,是不会对整个开发流程以及理论知识有牢固的认知的,对于怎样将所学的理论知识应用 ...
- asp.net core 实战之 redis 负载均衡和"高可用"实现
1.概述 分布式系统缓存已经变得不可或缺,本文主要阐述如何实现redis主从复制集群的负载均衡,以及 redis的"高可用"实现, 呵呵双引号的"高可用"并不是 ...
- Linux实战教学笔记08:Linux 文件的属性(上半部分)
第八节 Linux 文件的属性(上半部分) 标签(空格分隔):Linux实战教学笔记 第1章 Linux中的文件 1.1 文件属性概述(ls -lhi) linux里一切皆文件 Linux系统中的文件 ...
- Linux实战教学笔记07:Linux系统目录结构介绍
第七节 Linux系统目录结构介绍 标签(空格分隔):Linux实战教学笔记 第1章 前言 windows目录结构 C:\windows D:\Program Files E:\你懂的\精品 F:\你 ...
- Linux实战教学笔记06:Linux系统基础优化
第六节 Linux系统基础优化 标签(空格分隔):Linux实战教学笔记-陈思齐 第1章 基础环境 第2章 使用网易163镜像做yum源 默认国外的yum源速度很慢,所以换成国内的. 第一步:先备份 ...
- Linux实战教学笔记05:远程SSH连接服务与基本排错(新手扫盲篇)
第五节 远程SSH连接服务与基本排错 标签(空格分隔):Linux实战教学笔记-陈思齐 第1章 远程连接LInux系统管理 1.1 为什么要远程连接Linux系统 在实际的工作场景中,虚拟机界面或物理 ...
随机推荐
- [K8s] Kubernetes 集群部署管理方式对比, kops, kubeadm, kubespray
kops 是官方出的 Kubernetes Operations,生产级 K8s 的安装.升级和管理. 可以看做是适用于集群的 kubectl,kops 可帮助您从命令行创建,销毁,升级和维护生产级, ...
- dotnet OpenXML 解析 PPT 里表格的样式
在 PPT 里面的表格可以通过表格样式配置决定表格的样式,本文将和大家介绍如何获取和解析表格的样式 本文属于 OpenXML 系列博客,有一定的上下文,详细请参阅 Office 使用 OpenXML ...
- 修复 VisualStudio 构建时没有将 NuGet 的 PDB 符号文件拷贝到输出文件夹
本文告诉大家如何修复 VisualStudio 构建时没有将 NuGet 的 PDB 符号文件拷贝到输出文件夹的问题.如果 VisualStudio 构建时没有将 NuGet 的 PDB 符号文件拷贝 ...
- UCenter 1.6 数据字典
uc_admins 管理员权限表 字段名 数据类型 默认值 允许非空 自动递增 备注 uid mediumint(8) unsigned NO 是 用户ID username char(15) ...
- RT-Thread 时钟管理
一.时钟节拍 任何操作系统都需要提供一个时钟节拍,以供系统处理所有和时间有关的事件,如线程的延时.线程的时间片轮转调度以及定时器超时等.时钟节拍是特定的周期性中断,这个中断可以看做是系统心跳,中断之间 ...
- 零侵入!试试这款Api接口文档生成器!
大家好,我是 Java陈序员. 作为一名合格的程序员,不仅代码要写好,而且文档要写好. 虽然目前有成熟的框架可以快速生成接口文档,如大名鼎鼎的 Swagger.但是 Swagger 需要编写大量的注解 ...
- Solution Set - LCT
A[洛谷P3690]维护一个森林,支持询问路径xor和,连边(已连通则忽略),删边(无边则忽略),改变点权. B[洛谷P3203]\(n\)个装置编号为\(0,...,n-1\),从\(i\)可以一步 ...
- 21°C的冬天
2023-12-08 16:15:36 星期五 标题没有在胡说,今天穿着初秋的衣服还嫌热,尤其是蒋震图书馆的空调更是燥热. 明天就去考教资面试了,但是一点也没有学习的兴趣,今天下午四点就写完了这周所有 ...
- 【工程实践】go语言实现MerkleTree
简介 默克尔树(MerkleTree)是一种典型的二叉树结构,其主要特点为: 最下面的叶节点包含存储数据或其哈希值: 非叶子节点(包括中间节点和根节点)的内容为它的两个孩子节点内容的哈希值. 所以底层 ...
- CSRF(Pikachu靶场练习)
CSRF(get) 自己随便输点东西,回显登录失败,查看源码没发现什么 点开提示,登录进去看看 看到可以修改个人信息,我们把居住改成China,修改成功,没发现urlhttp://127.0.0.1/ ...