使用NetworkX模块绘制深度神经网络(DNN)结构图
本文将展示如何利用Python中的NetworkX模块来绘制深度神经网络(DNN)结构图。
在文章Keras入门(一)搭建深度神经网络(DNN)解决多分类问题中,我们创建的DNN结构图如下:

该DNN模型由输入层、隐藏层、输出层和softmax函数组成,每一层的神经元个数分别为4,5,6,3,3。不知道聪明的读者有没有发现,这张示意图完全是由笔者自己用Python绘制出来的,因为并不存在现成的结构图。那么,如何利用Python来绘制出这种相对复杂的神经网络的示意图呢?答案是利用NetworkX模块。
NetworkX是一个用Python语言开发的图论与复杂网络建模工具,内置了常用的图与复杂网络分析算法,可以方便地进行复杂网络数据分析、仿真建模等工作。NetworkX支持创建简单无向图、有向图和多重图,内置许多标准的图论算法,节点可为任意数据,支持任意的边值维度,功能丰富,简单易用。
首先,我们需要绘制出该DNN的大致框架,其Python代码如下:
# -*- coding:utf-8 -*-
import networkx as nx
import matplotlib.pyplot as plt
# 创建DAG
G = nx.DiGraph()
# 顶点列表
vertex_list = ['v'+str(i) for i in range(1, 22)]
# 添加顶点
G.add_nodes_from(vertex_list)
# 边列表
edge_list = [
('v1', 'v5'), ('v1', 'v6'), ('v1', 'v7'),('v1', 'v8'),('v1', 'v9'),
('v2', 'v5'), ('v2', 'v6'), ('v2', 'v7'),('v2', 'v8'),('v2', 'v9'),
('v3', 'v5'), ('v3', 'v6'), ('v3', 'v7'),('v3', 'v8'),('v3', 'v9'),
('v4', 'v5'), ('v4', 'v6'), ('v4', 'v7'),('v4', 'v8'),('v4', 'v9'),
('v5','v10'),('v5','v11'),('v5','v12'),('v5','v13'),('v5','v14'),('v5','v15'),
('v6','v10'),('v6','v11'),('v6','v12'),('v6','v13'),('v6','v14'),('v6','v15'),
('v7','v10'),('v7','v11'),('v7','v12'),('v7','v13'),('v7','v14'),('v7','v15'),
('v8','v10'),('v8','v11'),('v8','v12'),('v8','v13'),('v8','v14'),('v8','v15'),
('v9','v10'),('v9','v11'),('v9','v12'),('v9','v13'),('v9','v14'),('v9','v15'),
('v10','v16'),('v10','v17'),('v10','v18'),
('v11','v16'),('v11','v17'),('v11','v18'),
('v12','v16'),('v12','v17'),('v12','v18'),
('v13','v16'),('v13','v17'),('v13','v18'),
('v14','v16'),('v14','v17'),('v14','v18'),
('v15','v16'),('v15','v17'),('v15','v18'),
('v16','v19'),
('v17','v20'),
('v18','v21')
]
# 通过列表形式来添加边
G.add_edges_from(edge_list)
# 绘制DAG图
plt.title('DNN for iris') #图片标题
nx.draw(
G,
node_color = 'red', # 顶点颜色
edge_color = 'black', # 边的颜色
with_labels = True, # 显示顶点标签
font_size =10, # 文字大小
node_size =300 # 顶点大小
)
# 显示图片
plt.show()
可以看到,我们在代码中已经设置好了这22个神经元以及它们之间的连接情况,但绘制出来的结构如却是这样的:

这显然不是我们想要的结果,因为各神经的连接情况不明朗,而且很多神经都挤在了一起,看不清楚。之所以出现这种情况,是因为我们没有给神经元设置坐标,导致每个神经元都是随机放置的。
接下来,引入坐标机制,即设置好每个神经元节点的坐标,使得它们的位置能够按照事先设置好的来放置,其Python代码如下:
# -*- coding:utf-8 -*-
import networkx as nx
import matplotlib.pyplot as plt
# 创建DAG
G = nx.DiGraph()
# 顶点列表
vertex_list = ['v'+str(i) for i in range(1, 22)]
# 添加顶点
G.add_nodes_from(vertex_list)
# 边列表
edge_list = [
('v1', 'v5'), ('v1', 'v6'), ('v1', 'v7'),('v1', 'v8'),('v1', 'v9'),
('v2', 'v5'), ('v2', 'v6'), ('v2', 'v7'),('v2', 'v8'),('v2', 'v9'),
('v3', 'v5'), ('v3', 'v6'), ('v3', 'v7'),('v3', 'v8'),('v3', 'v9'),
('v4', 'v5'), ('v4', 'v6'), ('v4', 'v7'),('v4', 'v8'),('v4', 'v9'),
('v5','v10'),('v5','v11'),('v5','v12'),('v5','v13'),('v5','v14'),('v5','v15'),
('v6','v10'),('v6','v11'),('v6','v12'),('v6','v13'),('v6','v14'),('v6','v15'),
('v7','v10'),('v7','v11'),('v7','v12'),('v7','v13'),('v7','v14'),('v7','v15'),
('v8','v10'),('v8','v11'),('v8','v12'),('v8','v13'),('v8','v14'),('v8','v15'),
('v9','v10'),('v9','v11'),('v9','v12'),('v9','v13'),('v9','v14'),('v9','v15'),
('v10','v16'),('v10','v17'),('v10','v18'),
('v11','v16'),('v11','v17'),('v11','v18'),
('v12','v16'),('v12','v17'),('v12','v18'),
('v13','v16'),('v13','v17'),('v13','v18'),
('v14','v16'),('v14','v17'),('v14','v18'),
('v15','v16'),('v15','v17'),('v15','v18'),
('v16','v19'),
('v17','v20'),
('v18','v21')
]
# 通过列表形式来添加边
G.add_edges_from(edge_list)
# 指定绘制DAG图时每个顶点的位置
pos = {
'v1':(-2,1.5),
'v2':(-2,0.5),
'v3':(-2,-0.5),
'v4':(-2,-1.5),
'v5':(-1,2),
'v6': (-1,1),
'v7':(-1,0),
'v8':(-1,-1),
'v9':(-1,-2),
'v10':(0,2.5),
'v11':(0,1.5),
'v12':(0,0.5),
'v13':(0,-0.5),
'v14':(0,-1.5),
'v15':(0,-2.5),
'v16':(1,1),
'v17':(1,0),
'v18':(1,-1),
'v19':(2,1),
'v20':(2,0),
'v21':(2,-1)
}
# 绘制DAG图
plt.title('DNN for iris') #图片标题
plt.xlim(-2.2, 2.2) #设置X轴坐标范围
plt.ylim(-3, 3) #设置Y轴坐标范围
nx.draw(
G,
pos = pos, # 点的位置
node_color = 'red', # 顶点颜色
edge_color = 'black', # 边的颜色
with_labels = True, # 显示顶点标签
font_size =10, # 文字大小
node_size =300 # 顶点大小
)
# 显示图片
plt.show()
可以看到,在代码中,通过pos字典已经规定好了每个神经元节点的位置,那么,绘制好的DNN结构示意图如下:

可以看到,现在这个DNN模型的结构已经大致显现出来了。
接下来,我们需要对这个框架图进行更为细致地修改,需要修改的地方为:
- 去掉神经元节点的标签;
- 添加模型层的文字注释(比如Input layer).
其中,第二步的文字注释,我们借助opencv来完成。完整的Python代码如下:
# -*- coding:utf-8 -*-
import cv2
import networkx as nx
import matplotlib.pyplot as plt
# 创建DAG
G = nx.DiGraph()
# 顶点列表
vertex_list = ['v'+str(i) for i in range(1, 22)]
# 添加顶点
G.add_nodes_from(vertex_list)
# 边列表
edge_list = [
('v1', 'v5'), ('v1', 'v6'), ('v1', 'v7'),('v1', 'v8'),('v1', 'v9'),
('v2', 'v5'), ('v2', 'v6'), ('v2', 'v7'),('v2', 'v8'),('v2', 'v9'),
('v3', 'v5'), ('v3', 'v6'), ('v3', 'v7'),('v3', 'v8'),('v3', 'v9'),
('v4', 'v5'), ('v4', 'v6'), ('v4', 'v7'),('v4', 'v8'),('v4', 'v9'),
('v5','v10'),('v5','v11'),('v5','v12'),('v5','v13'),('v5','v14'),('v5','v15'),
('v6','v10'),('v6','v11'),('v6','v12'),('v6','v13'),('v6','v14'),('v6','v15'),
('v7','v10'),('v7','v11'),('v7','v12'),('v7','v13'),('v7','v14'),('v7','v15'),
('v8','v10'),('v8','v11'),('v8','v12'),('v8','v13'),('v8','v14'),('v8','v15'),
('v9','v10'),('v9','v11'),('v9','v12'),('v9','v13'),('v9','v14'),('v9','v15'),
('v10','v16'),('v10','v17'),('v10','v18'),
('v11','v16'),('v11','v17'),('v11','v18'),
('v12','v16'),('v12','v17'),('v12','v18'),
('v13','v16'),('v13','v17'),('v13','v18'),
('v14','v16'),('v14','v17'),('v14','v18'),
('v15','v16'),('v15','v17'),('v15','v18'),
('v16','v19'),
('v17','v20'),
('v18','v21')
]
# 通过列表形式来添加边
G.add_edges_from(edge_list)
# 指定绘制DAG图时每个顶点的位置
pos = {
'v1':(-2,1.5),
'v2':(-2,0.5),
'v3':(-2,-0.5),
'v4':(-2,-1.5),
'v5':(-1,2),
'v6': (-1,1),
'v7':(-1,0),
'v8':(-1,-1),
'v9':(-1,-2),
'v10':(0,2.5),
'v11':(0,1.5),
'v12':(0,0.5),
'v13':(0,-0.5),
'v14':(0,-1.5),
'v15':(0,-2.5),
'v16':(1,1),
'v17':(1,0),
'v18':(1,-1),
'v19':(2,1),
'v20':(2,0),
'v21':(2,-1)
}
# 绘制DAG图
plt.title('DNN for iris') #图片标题
plt.xlim(-2.2, 2.2) #设置X轴坐标范围
plt.ylim(-3, 3) #设置Y轴坐标范围
nx.draw(
G,
pos = pos, # 点的位置
node_color = 'red', # 顶点颜色
edge_color = 'black', # 边的颜色
font_size =10, # 文字大小
node_size =300 # 顶点大小
)
# 保存图片,图片大小为640*480
plt.savefig('E://data/DNN_sketch.png')
# 利用opencv模块对DNN框架添加文字注释
# 读取图片
imagepath = 'E://data/DNN_sketch.png'
image = cv2.imread(imagepath, 1)
# 输入层
cv2.rectangle(image, (85, 130), (120, 360), (255,0,0), 2)
cv2.putText(image, "Input Layer", (15, 390), 1, 1.5, (0, 255, 0), 2, 1)
# 隐藏层
cv2.rectangle(image, (190, 70), (360, 420), (255,0,0), 2)
cv2.putText(image, "Hidden Layer", (210, 450), 1, 1.5, (0, 255, 0), 2, 1)
# 输出层
cv2.rectangle(image, (420, 150), (460, 330), (255,0,0), 2)
cv2.putText(image, "Output Layer", (380, 360), 1, 1.5, (0, 255, 0), 2, 1)
# sofrmax层
cv2.rectangle(image, (530, 150), (570, 330), (255,0,0), 2)
cv2.putText(image, "Softmax Func", (450, 130), 1, 1.5, (0, 0, 255), 2, 1)
# 保存修改后的图片
cv2.imwrite('E://data/DNN.png', image)
这样生成的图片就是文章最开始给出的DNN的结构示意图。Bingo,搞定!
注意:本人现已开通微信公众号: Python爬虫与算法(微信号为:easy_web_scrape), 欢迎大家关注哦~~
使用NetworkX模块绘制深度神经网络(DNN)结构图的更多相关文章
- 深度神经网络DNN的多GPU数据并行框架 及其在语音识别的应用
深度神经网络(Deep Neural Networks, 简称DNN)是近年来机器学习领域中的研究热点,产生了广泛的应用.DNN具有深层结构.数千万参数需要学习,导致训练非常耗时.GPU有强大的计算能 ...
- 一天搞懂深度学习-训练深度神经网络(DNN)的要点
前言 这是<一天搞懂深度学习>的第二部分 一.选择合适的损失函数 典型的损失函数有平方误差损失函数和交叉熵损失函数. 交叉熵损失函数: 选择不同的损失函数会有不同的训练效果 二.mini- ...
- 深度神经网络(DNN)模型与前向传播算法
深度神经网络(Deep Neural Networks, 以下简称DNN)是深度学习的基础,而要理解DNN,首先我们要理解DNN模型,下面我们就对DNN的模型与前向传播算法做一个总结. 1. 从感知机 ...
- 神经网络6_CNN(卷积神经网络)、RNN(循环神经网络)、DNN(深度神经网络)概念区分理解
sklearn实战-乳腺癌细胞数据挖掘(博客主亲自录制视频教程,QQ:231469242) https://study.163.com/course/introduction.htm?courseId ...
- 深度神经网络(DNN)
深度神经网络(DNN) 深度神经网络(Deep Neural Networks, 以下简称DNN)是深度学习的基础,而要理解DNN,首先我们要理解DNN模型,下面我们就对DNN的模型与前向传播算法做一 ...
- 云中的机器学习:FPGA 上的深度神经网络
人工智能正在经历一场变革,这要得益于机器学习的快速进步.在机器学习领域,人们正对一类名为“深度学习”算法产生浓厚的兴趣,因为这类算法具有出色的大数据集性能.在深度学习中,机器可以在监督或不受监督的方式 ...
- Keras入门(一)搭建深度神经网络(DNN)解决多分类问题
Keras介绍 Keras是一个开源的高层神经网络API,由纯Python编写而成,其后端可以基于Tensorflow.Theano.MXNet以及CNTK.Keras 为支持快速实验而生,能够把 ...
- CNN(卷积神经网络)、RNN(循环神经网络)、DNN(深度神经网络)的内部网络结构有什么区别?
https://www.zhihu.com/question/34681168 CNN(卷积神经网络).RNN(循环神经网络).DNN(深度神经网络)的内部网络结构有什么区别?修改 CNN(卷积神经网 ...
- 深度神经网络(DNN)反向传播算法(BP)
在深度神经网络(DNN)模型与前向传播算法中,我们对DNN的模型和前向传播算法做了总结,这里我们更进一步,对DNN的反向传播算法(Back Propagation,BP)做一个总结. 1. DNN反向 ...
随机推荐
- Golang处理数据库的nil数据
在用golang获取数据库的数据的时候,难免会遇到可控field.这个时候拿到的数据如果直接用string, time.Time这样的类型来解析的话会遇到panic. 那么如何处理这个问题呢,第一个出 ...
- sort()方法的应用(二)
引用:函数作为参数 var fn_by = function(id) { return function(o, p) { var a, b; if (typeof o === "object ...
- ie页面数据导入共享版
为了解决自动输入号码的正确率,原来的版本一直采用鼠标检测的方法.但是这个方法在其他ie平台的使用不太方便.于是直接检测ie的方法.现在的这个版本完全不需要鼠标的检测.方便而且快速精准可靠. 经过作者的 ...
- Linux下使用openVPN连接到某个内网
推荐一个网站(比较全的介绍关于openvpn的客户端与服务端的配置) 点击我 此处我介绍我配置openvpn客户端连接的坑 我的机器为kali linux apt-get install openvp ...
- 破解StarUML3.01最新版 for Linux(Ubuntu16LTS)
原文地址:https://blog.csdn.net/yoyofreeman/article/details/80844739 chmod +x StarUML-3.0.1-x86_64.AppIma ...
- Appium + Java 测试 [百度地图] APP的一段简单脚本
1. 流程 进入 app ,手动处理前段预处理,程序一直等候到达指定搜索地名页面,填入[南通大学],点击[搜索] 2. Java 脚本 // part 1: 引入需要的包 import io.appi ...
- 微信技术分享:微信的海量IM聊天消息序列号生成实践(算法原理篇)
1.点评 对于IM系统来说,如何做到IM聊天消息离线差异拉取(差异拉取是为了节省流量).消息多端同步.消息顺序保证等,是典型的IM技术难点. 就像即时通讯网整理的以下IM开发干货系列一样: <I ...
- Day6:html和css
Day6:html和css 复习 margin: 0; padding: 0; <!DOCTYPE html> <html lang="en"> <h ...
- 性能调优之Mapping
Mapping层级的调优可能会花费时间,但是性能调优的效果确实非常显著的 优化Target,Source之后,可以调优Mapping 通常的方法是尽可能减少组件及组件的字段间不必要的连线 即尽可能用最 ...
- c++模板参数——数值类型推断
模板类中,或模板函数中,若限定模板参数为数值类型,可以使用如下方式进行判断. template<typename T> Fmt::Fmt(const char *fmt, T val) { ...