import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms
from torchvision import models,datasets
torch.__version__
'1.3.0'

4.2.2 使用Tensorboard在 PyTorch 中进行可视化

Tensorboard 简介

Tensorboard是tensorflow内置的一个可视化工具,它通过将tensorflow程序输出的日志文件的信息可视化使得tensorflow程序的理解、调试和优化更加简单高效。
Tensorboard的可视化依赖于tensorflow程序运行输出的日志文件,因而tensorboard和tensorflow程序在不同的进程中运行。
TensorBoard给我们提供了极其方便而强大的可视化环境。它可以帮助我们理解整个神经网络的学习过程、数据的分布、性能瓶颈等等。

tensorboard虽然是tensorflow内置的可视化工具,但是他们跑在不同的进程中,所以Github上已经有大神将tensorboard应用到Pytorch中 链接在这里

Tensorboard 安装

首先需要安装tensorboard

pip install tensorboard

~~ 然后再安装tensorboardx ~~

~~ pip install tensorboardx ~~
pytorch 1.1以后的版本内置了SummaryWriter 函数,所以不需要再安装tensorboardx了

安装完成后与 visdom一样执行独立的命令
tensorboard --logdir logs 即可启动,默认的端口是 6006,在浏览器中打开 http://localhost:6006/ 即可看到web页面。

这里要说明的是 微软的Edge浏览器css会无法加载,使用chrome正常显示

页面

与visdom不同,tensorboard针对不同的类型人为的区分多个标签,每一个标签页面代表不同的类型。
下面我们根据不同的页面功能做个简单的介绍,更多详细内容请参考官网

SCALAR

对标量数据进行汇总和记录,通常用来可视化训练过程中随着迭代次数准确率(val acc)、损失值(train/test loss)、学习率(learning rate)、每一层的权重和偏置的统计量(mean、std、max/min)等的变化曲线

IMAGES

可视化当前轮训练使用的训练/测试图片或者 feature maps

GRAPHS

可视化计算图的结构及计算图上的信息,通常用来展示网络的结构

HISTOGRAMS

可视化张量的取值分布,记录变量的直方图(统计张量随着迭代轮数的变化情况)

PROJECTOR

全称Embedding Projector 高维向量进行可视化

使用

在使用前请先去确认执行tensorboard --logdir logs 并保证 http://localhost:6006/ 页面能够正常打开

图像展示

首先介绍比较简单的功能,查看我们训练集和数据集中的图像,这里我们使用现成的图像作为展示。这里使用wikipedia上的一张猫的图片这里

引入 tensorboardX 包

# 这里的引用也要修改成torch的引用
#from tensorboardX import SummaryWriter
from torch.utils.tensorboard import SummaryWriter
cat_img = Image.open('./1280px-Felis_silvestris_catus_lying_on_rice_straw.jpg')
cat_img.size
(1280, 853)

这是一张1280x853的图,我们先把她变成224x224的图片,因为后面要使用的是vgg16

transform_224 = transforms.Compose([
transforms.Resize(224), # 这里要说明下 Scale 已经过期了,使用Resize
transforms.CenterCrop(224),
transforms.ToTensor(),
])
cat_img_224=transform_224(cat_img)

将图片展示在tebsorboard中:

writer = SummaryWriter(log_dir='./logs', comment='cat image') # 这里的logs要与--logdir的参数一样
writer.add_image("cat",cat_img_224)
writer.close()# 执行close立即刷新,否则将每120秒自动刷新

浏览器访问 http://localhost:6006/#images 即可看到猫的图片

更新损失函数

更新损失函数和训练批次我们与visdom一样使用模拟展示,这里用到的是tensorboard的SCALAR页面

x = torch.FloatTensor([100])
y = torch.FloatTensor([500]) for epoch in range(30):
x = x * 1.2
y = y / 1.1
loss = np.random.random()
with SummaryWriter(log_dir='./logs', comment='train') as writer: #可以直接使用python的with语法,自动调用close方法
writer.add_histogram('his/x', x, epoch)
writer.add_histogram('his/y', y, epoch)
writer.add_scalar('data/x', x, epoch)
writer.add_scalar('data/y', y, epoch)
writer.add_scalar('data/loss', loss, epoch)
writer.add_scalars('data/data_group', {'x': x,
'y': y}, epoch)

浏览器访问 http://localhost:6006/#scalars 即可看到图形

使用PROJECTOR对高维向量可视化

PROJECTOR的的原理是通过PCA,T-SNE等方法将高维向量投影到三维坐标系(降维度)。Embedding Projector从模型运行过程中保存的checkpoint文件中读取数据,默认使用主成分分析法(PCA)将高维数据投影到3D空间中,也可以通过设置设置选择T-SNE投影方法,这里做一个简单的展示。

我们还是用第三章的mnist代码

BATCH_SIZE=512
EPOCHS=20
train_loader = torch.utils.data.DataLoader(
datasets.MNIST('data', train=True, download=True,
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])),
batch_size=BATCH_SIZE, shuffle=True)
class ConvNet(nn.Module):
def __init__(self):
super().__init__()
# 1,28x28
self.conv1=nn.Conv2d(1,10,5) # 10, 24x24
self.conv2=nn.Conv2d(10,20,3) # 128, 10x10
self.fc1 = nn.Linear(20*10*10,500)
self.fc2 = nn.Linear(500,10)
def forward(self,x):
in_size = x.size(0)
out = self.conv1(x) #24
out = F.relu(out)
out = F.max_pool2d(out, 2, 2) #12
out = self.conv2(out) #10
out = F.relu(out)
out = out.view(in_size,-1)
out = self.fc1(out)
out = F.relu(out)
out = self.fc2(out)
out = F.log_softmax(out,dim=1)
return out
model = ConvNet()
optimizer = torch.optim.Adam(model.parameters())
def train(model, train_loader, optimizer, epoch):
n_iter=0
model.train()
for batch_idx, (data, target) in enumerate(train_loader):
optimizer.zero_grad()
output = model(data)
loss = F.nll_loss(output, target)
loss.backward()
optimizer.step()
if(batch_idx+1)%30 == 0:
n_iter=n_iter+1
print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
epoch, batch_idx * len(data), len(train_loader.dataset),
100. * batch_idx / len(train_loader), loss.item()))
#相对于以前的训练方法 主要增加了以下内容
out = torch.cat((output.data, torch.ones(len(output), 1)), 1) # 因为是投影到3D的空间,所以我们只需要3个维度
with SummaryWriter(log_dir='./logs', comment='mnist') as writer:
#使用add_embedding方法进行可视化展示
writer.add_embedding(
out,
metadata=target.data,
label_img=data.data,
global_step=n_iter)

这里节省时间,只训练一次

train(model, train_loader, optimizer, 0)
Train Epoch: 0 [14848/60000 (25%)]	Loss: 0.352312
Train Epoch: 0 [30208/60000 (50%)] Loss: 0.202950
Train Epoch: 0 [45568/60000 (75%)] Loss: 0.156494

打开 http://localhost:6006/#projector 即可看到效果。

目前测试投影这部分也是有问题的,根据官网文档的代码进行测试,也显示不出来,正在找原因

绘制网络结构

在pytorch中我们可以使用print直接打印出网络的结构,但是这种方法可视化效果不好,这里使用tensorboard的GRAPHS来实现网络结构的可视化。
由于pytorch使用的是动态图计算,所以我们这里要手动进行一次前向的传播.

使用Pytorch已经构建好的模型进行展示

vgg16 = models.vgg16(pretrained=True) # 这里下载预训练好的模型
print(vgg16) # 打印一下这个模型
VGG(
(features): Sequential(
(0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): ReLU(inplace=True)
(2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(3): ReLU(inplace=True)
(4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(6): ReLU(inplace=True)
(7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(8): ReLU(inplace=True)
(9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(11): ReLU(inplace=True)
(12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(13): ReLU(inplace=True)
(14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(15): ReLU(inplace=True)
(16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(18): ReLU(inplace=True)
(19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(20): ReLU(inplace=True)
(21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(22): ReLU(inplace=True)
(23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(25): ReLU(inplace=True)
(26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(27): ReLU(inplace=True)
(28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(29): ReLU(inplace=True)
(30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)
(avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
(classifier): Sequential(
(0): Linear(in_features=25088, out_features=4096, bias=True)
(1): ReLU(inplace=True)
(2): Dropout(p=0.5, inplace=False)
(3): Linear(in_features=4096, out_features=4096, bias=True)
(4): ReLU(inplace=True)
(5): Dropout(p=0.5, inplace=False)
(6): Linear(in_features=4096, out_features=1000, bias=True)
)
)

在前向传播前,先要把图片做一些调整

transform_2 = transforms.Compose([
transforms.Resize(224),
transforms.CenterCrop((224,224)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])

使用上一张猫的图片进行前向传播

vgg16_input=transform_2(cat_img)[np.newaxis]# 因为pytorch的是分批次进行的,所以我们这里建立一个批次为1的数据集
vgg16_input.shape
torch.Size([1, 3, 224, 224])

开始前向传播,打印输出值

out = vgg16(vgg16_input)
_, preds = torch.max(out.data, 1)
label=preds.numpy()[0]
label
287

将结构图在tensorboard进行展示

with SummaryWriter(log_dir='./logs', comment='vgg161') as writer:
writer.add_graph(vgg16, vgg16_input)

对于Pytorch的1.3版本来说,实测 SummaryWriter在处理结构图的时候是有问题的(或者是需要加什么参数,目前我还没找到),所以建议大家继续使用tensorboardx。

[Pytorch框架] 4.2.2 使用Tensorboard在 PyTorch 中进行可视化的更多相关文章

  1. PyTorch框架+Python 3面向对象编程学习笔记

    一.CNN情感分类中的面向对象部分 sparse.py super(Embedding, self).__init__() 表示需要父类初始化,即要运行父类的_init_(),如果没有这个,则要自定义 ...

  2. 手写数字识别 卷积神经网络 Pytorch框架实现

    MNIST 手写数字识别 卷积神经网络 Pytorch框架 谨此纪念刚入门的我在卷积神经网络上面的摸爬滚打 说明 下面代码是使用pytorch来实现的LeNet,可以正常运行测试,自己添加了一些注释, ...

  3. 小白学习之pytorch框架(1)-torch.nn.Module+squeeze(unsqueeze)

    我学习pytorch框架不是从框架开始,从代码中看不懂的pytorch代码开始的 可能由于是小白的原因,个人不喜欢一些一下子粘贴老多行代码的博主或者一些弄了一堆概念,导致我更迷惑还增加了畏惧的情绪(个 ...

  4. 全面解析Pytorch框架下模型存储,加载以及冻结

    最近在做试验中遇到了一些深度网络模型加载以及存储的问题,因此整理了一份比较全面的在 PyTorch 框架下有关模型的问题.首先咱们先定义一个网络来进行后续的分析: 1.本文通用的网络模型 import ...

  5. pytorch框架对RTX 2080Ti RTX 3090的支持与性能测试

    时间点:202011-18 一.背景 2020年9月nvidia发布了30系列的显卡.比起20系列网上的评价是:性能翻倍,价格减半. 最近正好本人手上有RTX 2080Ti 和 RTX 3090,所以 ...

  6. Tensorboard教程:监控指标可视化

    Tensorflow监控指标可视化 觉得有用的话,欢迎一起讨论相互学习~Follow Me 参考文献 强烈推荐Tensorflow实战Google深度学习框架 实验平台: Tensorflow1.4. ...

  7. RobotFramework自动化测试框架-Selenium Web自动化(三)关于在RobotFramework中如何使用Selenium很全的总结(下)

    本文紧接着RobotFramework自动化测试框架-Selenium Web自动化(二)关于在RobotFramework中如何使用Selenium很全的总结(上)继续分享RobotFramewor ...

  8. “造轮运动”之 ORM框架系列(二)~ 说说我心目中的ORM框架

    ORM概念解析 首先梳理一下ORM的概念,ORM的全拼是Object Relation Mapping (对象关系映射),其中Object就是面向对象语言中的对象,本文使用的是c#语言,所以就是.ne ...

  9. Combine 框架,从0到1 —— 5.Combine 中的 Subjects

    本文首发于 Ficow Shen's Blog,原文地址: Combine 框架,从0到1 -- 5.Combine 中的 Subjects. 内容概览 前言 PassthroughSubject C ...

  10. ArXiv最受欢迎开源深度学习框架榜单:TensorFlow第一,PyTorch第四

    [导读]Kears作者François Chollet刚刚在Twitter贴出最近三个月在arXiv提到的深度学习框架,TensorFlow不出意外排名第一,Keras排名第二.随后是Caffe.Py ...

随机推荐

  1. SQLyog 13.1.1.0注册码证书秘钥

    注册信息: Name:(用户名随意) License Key: Professional: 8e053a86-cdd3-48ed-b5fe-94c51b3d343c Enterprise: a4668 ...

  2. 安装Win11需要网络才能下一步怎么跳过

    1.先Shift+F10打开命令提示符 2.运行C:\Windows\System32\oobe\BypassNRO.cmd 3.自动重启来到联网这一步,多了一个没有网络的选项,进入.

  3. arcengine动态显示所需字段值

    需求:实现和GIS桌面端中Identify的类似功能,鼠标滑动的时候可以显示鼠标所在位置的要素的指定字段的值.. 主要操作流程: ①先打开一个对话框,用于选择需要显示的图层和字段名 ②点击确定之后,在 ...

  4. 如果摄像头不支持Web Socket,猿大师播放器还能在网页中播放RTSP流吗?

    问: 我们的情况比较复杂,摄像头设备品牌和数量都比较多,分布在全国各地都有,地点分布比较广泛,有的甚至是比较老的型号,如果摄像头设备不支持Web Socket,猿大师播放器还可以在网页中播放RTSP流 ...

  5. 不使用setTimeout的延迟执行

    function sleep(ms){ var time = new Date(); time.setTime(time.getTime() + ms); while(new Date().getTi ...

  6. Java笔记第五弹

    编码表 将字符存储到计算机中----编码:反之,则为解码: GBK编码:最常用的中文码表 GB18030--最新的中文码表 Unicode字符集:业界的一种标准,也称为统一码.万国码 UTF-8编码: ...

  7. 除select外查询数据的另一种姿势

    1.24 1.[GYCTF2020]Blacklist buuctf上的题目 1.解题过程 输入1会返回一个数组,加上单引号就报错了,说明存在注入 以前做过类似的估计是堆叠注入,尝试一下 注入成功 正 ...

  8. PMP常见会议小结

    转载请注明出处: 会议是吸引项目团队和其他干系人参与的重要方式.它们是整个项目的主要沟通方式. 一. 项目启动会 召开时间:是启动阶段结束时召开的会议. 主要任务:发布项目章程,并任命项目经理,赋予项 ...

  9. 【备忘录】 主定理 Master Theorem (转载)

    备忘录 https://zhuanlan.zhihu.com/p/113406812

  10. 使用 zeromq与cppzmq 程序退出遇到的坑

    在使用zeromq 退出的时候还遇到一点坑,对于服务deaman(守护进程)化的进程可能会遇到这个问题. 现象: 这个问题导致的现象是服务一旦关闭(stop),就会 core dump,core du ...