qwe框架- CNN 实现
CNN实现
概述
我在qwe中有两种,第一种是按照Ng课程中的写法,多层循环嵌套得到每次的“小方格”,然后WX+b,这样的做法是最简单,直观。但是效率极其慢。基本跑个10张以内图片都会卡的要死。
第二种方法是使用img2col,将其转换为对应的矩阵,然后直接做一次矩阵乘法运算。
先看第一种
def forward(self, X):
m, n_H_prev, n_W_prev, n_C_prev = X.shape
(f, f, n_C_prev, n_C) = self.W.shape
n_H = int((n_H_prev - f + 2 * self.pad) / self.stride) + 1
n_W = int((n_W_prev - f + 2 * self.pad) / self.stride) + 1
n_H, n_W, n_C = self.output_size
Z = np.zeros((m, n_H, n_W, n_C))
X_pad = zero_pad(X, self.pad)
for i in range(m):
for h in range(n_H):
for w in range(n_W):
for c in range(n_C):
vert_start = h * self.stride
vert_end = vert_start + f
horiz_start = w * self.stride
horiz_end = horiz_start + f
A_slice_prev =X_pad[i,vert_start:vert_end, horiz_start:horiz_end, :]
Z[i,h,w,c] = conv_single_step(A_slice_prev, self.W[...,c], self.b[...,c])
def conv_single_step(X, W, b):
# 对一个裁剪图像进行卷积
# X.shape = f, f, prev_channel_size
return np.sum(np.multiply(X, W) + b)
对于m,n_H,n_W,n_C循环就是取得裁剪小方块,可以看到这里的计算复杂度m * n_H * n_W * n_C * (f*f的矩阵计算)
第二种方法,先转换成大矩阵,再进行一次矩阵运算,相当于节省了多次小矩阵运算时间,这还是很可观的,能查个几十倍的速度。
img2col原理很简单,详情可参考caffe im2col
就是循环将每一部分都拉长成一维矩阵拼凑起来。
对于CNN来说,H就是要计算方块的个数即m(样本数) n_H(最终生成图像行数)n_W(最终生成图像列数),W就是f(核kernel长)f(核宽)*(输入样本通道输)
然后还要把参数矩阵W也拉成这个样子,H就是f(核长)f(核宽)(输入样本通道输),W列数就是核数kernel_size
如下图
def img2col(X, pad, stride, f):
pass
ff = f * f
m, n_H_prev, n_W_prev, n_C_prev= X.shape
n_H = int((n_H_prev - f + 2 * pad) / stride) + 1
n_W = int((n_W_prev - f + 2 * pad) / stride) + 1
Z = np.zeros((m * n_H * n_W, f * f * n_C_prev))
X_pad = np.pad(X, ((0, 0), (pad, pad), (pad, pad), (0, 0)), 'constant', constant_values=0)
row = -1
for i in range(m):
for h in range(n_H):
for w in range(n_W):
row += 1
vert_start = h * stride
horiz_start = w * stride
for col in range(f * f * n_C_prev):
t = col // n_C_prev
hh = t // f
ww = t % f
cc = col % n_C_prev
Z[row, col] = X_pad[i, vert_start + hh, horiz_start + ww, cc]
def speed_forward(model, X):
W = model.W
b = model.b
stride = model.stride
pad = model.pad
(n_C_prev, f, f, n_C) = W.shape
m, n_H_prev, n_W_prev, n_C_prev = X.shape
n_H = int((n_H_prev - f + 2 * pad) / stride) + 1
n_W = int((n_W_prev - f + 2 * pad) / stride) + 1
# WW = W.swapaxes(2,1)
# WW = WW.swapaxes(1,0)
XX = img2col(X, pad, stride, f)
# WW = WW.reshape(f*f*n_C_prev, n_C)
WW = W.reshape(f*f*n_C_prev, n_C)
model.XX = XX
model.WW = WW
Z = np.dot(XX, WW) + b
return Z.reshape(m, n_H, n_W, n_C)
这种耗时操作,最好使用Cython扩展来写,不然速度还是不够理想。Cython扩展代码code
反向传播同理,具体代码参考
github
qwe框架- CNN 实现的更多相关文章
- 深度学习原理与框架-CNN在文本分类的应用 1.tf.nn.embedding_lookup(根据索引数据从数据中取出数据) 2.saver.restore(加载sess参数)
1. tf.nn.embedding_lookup(W, X) W的维度为[len(vocabulary_list), 128], X的维度为[?, 8],组合后的维度为[?, 8, 128] 代码说 ...
- 深蓝色 --ppt
Deep Learning of Binary Hash Codes for Fast Image Retrieval [Paper] [Code-Caffe] 1. 摘要 针对图像检索问题,提出简单 ...
- [基础]Deep Learning的基础概念
目录 DNN CNN DNN VS CNN Example 卷积的好处why convolution? DCNN 卷积核移动的步长 stride 激活函数 active function 通道 cha ...
- qwe 简易深度框架
qwe github地址 简介 简单的深度框架,参考Ng的深度学习课程作业,使用了keras的API设计. 方便了解网络具体实现,避免深陷于成熟框架的细节和一些晦涩的优化代码. 网络层实现了Dense ...
- 【深度学习系列3】 Mariana CNN并行框架与图像识别
[深度学习系列3] Mariana CNN并行框架与图像识别 本文是腾讯深度学习系列文章的第三篇,聚焦于腾讯深度学习平台Mariana中深度卷积神经网络Deep CNNs的多GPU模型并行和数据并行框 ...
- 卷积神经网络CNN与深度学习常用框架的介绍与使用
一.神经网络为什么比传统的分类器好 1.传统的分类器有 LR(逻辑斯特回归) 或者 linear SVM ,多用来做线性分割,假如所有的样本可以看做一个个点,如下图,有蓝色的点和绿色的点,传统的分类器 ...
- 我所写的CNN框架 VS caffe
我所写的CNN框架 VS caffe 一个月前.自己模仿caffe实现了一个卷积神经网络的框架. 同样点 1无缝支持CPU和GPU模式,GPU模式使用cuda实现. 不同点 1我的CNN不依赖与不论什 ...
- ubuntu之路——day19.2 开源框架与迁移、CNN中的数据扩充
开源框架与迁移 上面介绍了一些已经取得很好成绩的CNN框架,我们可以直接从GitHub上下载这些神经网络的结构和已经在ImageNet等数据集上训练好的权重超参数. 在应用于我们自己的数据时. 1.如 ...
- CNN基础框架简介
卷积神经网络简介 卷积神经网络是多层感知机的变种,由生物学家休博尔和维瑟尔在早期关于猫视觉皮层的研究发展而来.视觉皮层的细胞存在一个复杂的构造,这些细胞对视觉输入空间的子区域非常敏感,我们称之为感受野 ...
随机推荐
- eclipse每次闪退后都提示查看\workspace\.metadata\.log
错误如下: 找到<workspace>/.metadata/.plugins/org.eclipse.e4.workbench/workbench.xmi"文件,将其删掉,再重启 ...
- exit、_exit、abort、return的区别
转自:http://www.cnblogs.com/fixer/archive/2013/05/14/3078660.html _exit(): 跟exit功能大致相同,区别在于_exit不会清空所有 ...
- Python字典(dict)使用技巧
字典dict是Python中使用频率非常高的数据结构,关于它的使用,也有许多的小技巧,掌握这些小技巧会让你高效地的使用dict,也会让你的代码更简洁. 1.默认值 假设name_for_userid存 ...
- ABP官方文档翻译 3.3 仓储
仓储 默认仓储 自定义仓储 自定义仓储接口 自定义仓储实现 基础仓储方法管理数据库连接 查询 获取单个实体 获取实体列表 关于IQueryable 自定义返回值 插入 更新 删除 其他 关于异步方法 ...
- 大白话说Java反射:入门、使用、原理
文章首发于[博客园-陈树义],点击跳转到原文<大白话说Java反射:入门.进阶.原理> 反射之中包含了一个「反」字,所以想要解释反射就必须先从「正」开始解释. 一般情况下,我们使用某个类时 ...
- BZOJ 2957: 楼房重建 [线段树 信息合并]
传送门 题意:转换成斜率然后维护区间的上升序列(从区间第一个数开始的单调上升序列) 区间保存这个区间的最长序列的长度$ls$和最大值$mx$ 如何合并两个区间信息? 左区间一定选择,右区间递归寻找第一 ...
- BZOJ 1180: [CROATIAN2009]OTOCI [LCT]
1180: [CROATIAN2009]OTOCI Time Limit: 50 Sec Memory Limit: 162 MBSubmit: 961 Solved: 594[Submit][S ...
- HTML标签的命名/CSS标准化命名大全
在一个内容较多的HTML页面中,需要设计许多不同的框架,再为这些不同的框架及内容进行分类,给予相应的名称,从而使得网页结构更加清晰,也为工作提供了方便.许多新手朋友在设计一个HTML文件时,可能只会依 ...
- win10安装elementary os双系统
elementary os是ubuntu的一个分支,界面有点像苹果,比较漂亮.如图: 从已有的磁盘中划出一块空白分区,将elementary单独安装在这个分区里,这个分区需要比其他分区的剩余空间都要大 ...
- [HAOI2009]毛毛虫
题目描述 对于一棵树,我们可以将某条链和与该链相连的边抽出来,看上去就象成一个毛毛虫,点数越多,毛毛虫就越大.例如下图左边的树(图 1 )抽出一部分就变成了右边的一个毛毛虫了(图 2 ). 输入输出格 ...