qwe框架- CNN 实现
CNN实现
概述
我在qwe中有两种,第一种是按照Ng课程中的写法,多层循环嵌套得到每次的“小方格”,然后WX+b,这样的做法是最简单,直观。但是效率极其慢。基本跑个10张以内图片都会卡的要死。
第二种方法是使用img2col,将其转换为对应的矩阵,然后直接做一次矩阵乘法运算。
先看第一种
def forward(self, X):
m, n_H_prev, n_W_prev, n_C_prev = X.shape
(f, f, n_C_prev, n_C) = self.W.shape
n_H = int((n_H_prev - f + 2 * self.pad) / self.stride) + 1
n_W = int((n_W_prev - f + 2 * self.pad) / self.stride) + 1
n_H, n_W, n_C = self.output_size
Z = np.zeros((m, n_H, n_W, n_C))
X_pad = zero_pad(X, self.pad)
for i in range(m):
for h in range(n_H):
for w in range(n_W):
for c in range(n_C):
vert_start = h * self.stride
vert_end = vert_start + f
horiz_start = w * self.stride
horiz_end = horiz_start + f
A_slice_prev =X_pad[i,vert_start:vert_end, horiz_start:horiz_end, :]
Z[i,h,w,c] = conv_single_step(A_slice_prev, self.W[...,c], self.b[...,c])
def conv_single_step(X, W, b):
# 对一个裁剪图像进行卷积
# X.shape = f, f, prev_channel_size
return np.sum(np.multiply(X, W) + b)
对于m,n_H,n_W,n_C循环就是取得裁剪小方块,可以看到这里的计算复杂度m * n_H * n_W * n_C * (f*f的矩阵计算)
第二种方法,先转换成大矩阵,再进行一次矩阵运算,相当于节省了多次小矩阵运算时间,这还是很可观的,能查个几十倍的速度。
img2col原理很简单,详情可参考caffe im2col
就是循环将每一部分都拉长成一维矩阵拼凑起来。
对于CNN来说,H就是要计算方块的个数即m(样本数) n_H(最终生成图像行数)n_W(最终生成图像列数),W就是f(核kernel长)f(核宽)*(输入样本通道输)
然后还要把参数矩阵W也拉成这个样子,H就是f(核长)f(核宽)(输入样本通道输),W列数就是核数kernel_size
如下图
def img2col(X, pad, stride, f):
pass
ff = f * f
m, n_H_prev, n_W_prev, n_C_prev= X.shape
n_H = int((n_H_prev - f + 2 * pad) / stride) + 1
n_W = int((n_W_prev - f + 2 * pad) / stride) + 1
Z = np.zeros((m * n_H * n_W, f * f * n_C_prev))
X_pad = np.pad(X, ((0, 0), (pad, pad), (pad, pad), (0, 0)), 'constant', constant_values=0)
row = -1
for i in range(m):
for h in range(n_H):
for w in range(n_W):
row += 1
vert_start = h * stride
horiz_start = w * stride
for col in range(f * f * n_C_prev):
t = col // n_C_prev
hh = t // f
ww = t % f
cc = col % n_C_prev
Z[row, col] = X_pad[i, vert_start + hh, horiz_start + ww, cc]
def speed_forward(model, X):
W = model.W
b = model.b
stride = model.stride
pad = model.pad
(n_C_prev, f, f, n_C) = W.shape
m, n_H_prev, n_W_prev, n_C_prev = X.shape
n_H = int((n_H_prev - f + 2 * pad) / stride) + 1
n_W = int((n_W_prev - f + 2 * pad) / stride) + 1
# WW = W.swapaxes(2,1)
# WW = WW.swapaxes(1,0)
XX = img2col(X, pad, stride, f)
# WW = WW.reshape(f*f*n_C_prev, n_C)
WW = W.reshape(f*f*n_C_prev, n_C)
model.XX = XX
model.WW = WW
Z = np.dot(XX, WW) + b
return Z.reshape(m, n_H, n_W, n_C)
这种耗时操作,最好使用Cython扩展来写,不然速度还是不够理想。Cython扩展代码code
反向传播同理,具体代码参考
github
qwe框架- CNN 实现的更多相关文章
- 深度学习原理与框架-CNN在文本分类的应用 1.tf.nn.embedding_lookup(根据索引数据从数据中取出数据) 2.saver.restore(加载sess参数)
1. tf.nn.embedding_lookup(W, X) W的维度为[len(vocabulary_list), 128], X的维度为[?, 8],组合后的维度为[?, 8, 128] 代码说 ...
- 深蓝色 --ppt
Deep Learning of Binary Hash Codes for Fast Image Retrieval [Paper] [Code-Caffe] 1. 摘要 针对图像检索问题,提出简单 ...
- [基础]Deep Learning的基础概念
目录 DNN CNN DNN VS CNN Example 卷积的好处why convolution? DCNN 卷积核移动的步长 stride 激活函数 active function 通道 cha ...
- qwe 简易深度框架
qwe github地址 简介 简单的深度框架,参考Ng的深度学习课程作业,使用了keras的API设计. 方便了解网络具体实现,避免深陷于成熟框架的细节和一些晦涩的优化代码. 网络层实现了Dense ...
- 【深度学习系列3】 Mariana CNN并行框架与图像识别
[深度学习系列3] Mariana CNN并行框架与图像识别 本文是腾讯深度学习系列文章的第三篇,聚焦于腾讯深度学习平台Mariana中深度卷积神经网络Deep CNNs的多GPU模型并行和数据并行框 ...
- 卷积神经网络CNN与深度学习常用框架的介绍与使用
一.神经网络为什么比传统的分类器好 1.传统的分类器有 LR(逻辑斯特回归) 或者 linear SVM ,多用来做线性分割,假如所有的样本可以看做一个个点,如下图,有蓝色的点和绿色的点,传统的分类器 ...
- 我所写的CNN框架 VS caffe
我所写的CNN框架 VS caffe 一个月前.自己模仿caffe实现了一个卷积神经网络的框架. 同样点 1无缝支持CPU和GPU模式,GPU模式使用cuda实现. 不同点 1我的CNN不依赖与不论什 ...
- ubuntu之路——day19.2 开源框架与迁移、CNN中的数据扩充
开源框架与迁移 上面介绍了一些已经取得很好成绩的CNN框架,我们可以直接从GitHub上下载这些神经网络的结构和已经在ImageNet等数据集上训练好的权重超参数. 在应用于我们自己的数据时. 1.如 ...
- CNN基础框架简介
卷积神经网络简介 卷积神经网络是多层感知机的变种,由生物学家休博尔和维瑟尔在早期关于猫视觉皮层的研究发展而来.视觉皮层的细胞存在一个复杂的构造,这些细胞对视觉输入空间的子区域非常敏感,我们称之为感受野 ...
随机推荐
- angular4在prod模式下的JIT编译问题
最近利用angular4开发一个项目,由于画面中的显示都是从数据表中读取,通过设置显示FLAG和显示顺序对画面布局按既定规则控制的, 所以必须利用动态编译实现. 方法如下, 1,获取JitCompil ...
- JVM的内存分区
JVM的内存分区 这篇文章尝试讨论清楚JVM的内存分区情况. 1. JVM的内存和系统内存的关系 下图是对系统内存及JVM内存的大致描绘 对大多数操作系统,内存可以分为物理内存RAM及Sw ...
- bat自动打包压缩实现
1.引言 本文档的编辑目的是为了实bat脚本自动打包功能,包含包的名字命名,压缩文件内外层文件夹的名字:包含svn版本号等: 2.实现介绍 (1)获取svn号,生成批处理文件 写一个pak.bat文件 ...
- 浅谈对SpringMVC的认识
SpringMVC概念: 他是一个轻量级的开源框架,应用于表现层,基于MVC的设计模式. SpringMVC的特点: 1.他是单例的可以设置成多例. 2.他的线程是安全的 ...
- 豆瓣爬虫小记(lowB版)
爬虫小记 学习玩python正则之后,想利用正则知识学学网络爬虫. 需求分析 按照自己平时浏览的网页,留意下哪个网页的信息对自己有价值,分析要爬取哪些网页信息.这里我先爬取简单的静态网页,豆瓣网经典电 ...
- Java垃圾回收机制[转]
原文地址:http://blog.csdn.net/zsuguangh/article/details/6429592 综合了若干人的blog- 1. 垃圾回收的意义 在C++中,对象所占的内存在程序 ...
- HDU 4315 Climbing the Hill [阶梯Nim]
传送门 题意: 和上题基本一样:山顶可以有多人,谁先把king放到山顶谁就胜 并不太明白 #include <iostream> #include <cstdio> #incl ...
- mysql DML DDL DCL
DML(data manipulation language): 它们是SELECT.UPDATE.INSERT.DELETE,就象它的名字一样,这4条命令是用来对数据库里的数据进行操作的语言 D ...
- 恢复linux系统文件夹颜色
/etc/DIR_COLORS 默认值 # Background color codes:# 40=black 41=red 42=green 43=yellow 44=blue 45=magenta ...
- MTU介绍以及在windows和linux下怎么设置MTU值
最大传输单元MTU(Maximum Transmission Unit,MTU)是指一种通信协议的某一层上面所能通过的最大数据包大小(以字节为单位).最大传输单元这个参数通常与通信接口有关(网络接口卡 ...