CNN实现

概述

我在qwe中有两种,第一种是按照Ng课程中的写法,多层循环嵌套得到每次的“小方格”,然后WX+b,这样的做法是最简单,直观。但是效率极其慢。基本跑个10张以内图片都会卡的要死。

第二种方法是使用img2col,将其转换为对应的矩阵,然后直接做一次矩阵乘法运算。

先看第一种

def forward(self, X):
m, n_H_prev, n_W_prev, n_C_prev = X.shape
(f, f, n_C_prev, n_C) = self.W.shape
n_H = int((n_H_prev - f + 2 * self.pad) / self.stride) + 1
n_W = int((n_W_prev - f + 2 * self.pad) / self.stride) + 1
n_H, n_W, n_C = self.output_size Z = np.zeros((m, n_H, n_W, n_C))
X_pad = zero_pad(X, self.pad)
for i in range(m):
for h in range(n_H):
for w in range(n_W):
for c in range(n_C):
vert_start = h * self.stride
vert_end = vert_start + f
horiz_start = w * self.stride
horiz_end = horiz_start + f
A_slice_prev =X_pad[i,vert_start:vert_end, horiz_start:horiz_end, :]
Z[i,h,w,c] = conv_single_step(A_slice_prev, self.W[...,c], self.b[...,c]) def conv_single_step(X, W, b):
# 对一个裁剪图像进行卷积
# X.shape = f, f, prev_channel_size
return np.sum(np.multiply(X, W) + b)

对于m,n_H,n_W,n_C循环就是取得裁剪小方块,可以看到这里的计算复杂度m * n_H * n_W * n_C * (f*f的矩阵计算)

第二种方法,先转换成大矩阵,再进行一次矩阵运算,相当于节省了多次小矩阵运算时间,这还是很可观的,能查个几十倍的速度。

img2col原理很简单,详情可参考caffe im2col

就是循环将每一部分都拉长成一维矩阵拼凑起来。

对于CNN来说,H就是要计算方块的个数即m(样本数) n_H(最终生成图像行数)n_W(最终生成图像列数),W就是f(核kernel长)f(核宽)*(输入样本通道输)

然后还要把参数矩阵W也拉成这个样子,H就是f(核长)f(核宽)(输入样本通道输),W列数就是核数kernel_size

如下图



def img2col(X, pad, stride, f):
pass
ff = f * f
m, n_H_prev, n_W_prev, n_C_prev= X.shape
n_H = int((n_H_prev - f + 2 * pad) / stride) + 1
n_W = int((n_W_prev - f + 2 * pad) / stride) + 1
Z = np.zeros((m * n_H * n_W, f * f * n_C_prev))
X_pad = np.pad(X, ((0, 0), (pad, pad), (pad, pad), (0, 0)), 'constant', constant_values=0)
row = -1 for i in range(m):
for h in range(n_H):
for w in range(n_W):
row += 1
vert_start = h * stride
horiz_start = w * stride
for col in range(f * f * n_C_prev):
t = col // n_C_prev
hh = t // f
ww = t % f
cc = col % n_C_prev
Z[row, col] = X_pad[i, vert_start + hh, horiz_start + ww, cc] def speed_forward(model, X):
W = model.W
b = model.b
stride = model.stride
pad = model.pad
(n_C_prev, f, f, n_C) = W.shape
m, n_H_prev, n_W_prev, n_C_prev = X.shape n_H = int((n_H_prev - f + 2 * pad) / stride) + 1
n_W = int((n_W_prev - f + 2 * pad) / stride) + 1 # WW = W.swapaxes(2,1)
# WW = WW.swapaxes(1,0) XX = img2col(X, pad, stride, f)
# WW = WW.reshape(f*f*n_C_prev, n_C)
WW = W.reshape(f*f*n_C_prev, n_C)
model.XX = XX
model.WW = WW Z = np.dot(XX, WW) + b
return Z.reshape(m, n_H, n_W, n_C)

这种耗时操作,最好使用Cython扩展来写,不然速度还是不够理想。Cython扩展代码code

反向传播同理,具体代码参考

github

qwe框架- CNN 实现的更多相关文章

  1. 深度学习原理与框架-CNN在文本分类的应用 1.tf.nn.embedding_lookup(根据索引数据从数据中取出数据) 2.saver.restore(加载sess参数)

    1. tf.nn.embedding_lookup(W, X) W的维度为[len(vocabulary_list), 128], X的维度为[?, 8],组合后的维度为[?, 8, 128] 代码说 ...

  2. 深蓝色 --ppt

    Deep Learning of Binary Hash Codes for Fast Image Retrieval [Paper] [Code-Caffe] 1. 摘要 针对图像检索问题,提出简单 ...

  3. [基础]Deep Learning的基础概念

    目录 DNN CNN DNN VS CNN Example 卷积的好处why convolution? DCNN 卷积核移动的步长 stride 激活函数 active function 通道 cha ...

  4. qwe 简易深度框架

    qwe github地址 简介 简单的深度框架,参考Ng的深度学习课程作业,使用了keras的API设计. 方便了解网络具体实现,避免深陷于成熟框架的细节和一些晦涩的优化代码. 网络层实现了Dense ...

  5. 【深度学习系列3】 Mariana CNN并行框架与图像识别

    [深度学习系列3] Mariana CNN并行框架与图像识别 本文是腾讯深度学习系列文章的第三篇,聚焦于腾讯深度学习平台Mariana中深度卷积神经网络Deep CNNs的多GPU模型并行和数据并行框 ...

  6. 卷积神经网络CNN与深度学习常用框架的介绍与使用

    一.神经网络为什么比传统的分类器好 1.传统的分类器有 LR(逻辑斯特回归) 或者 linear SVM ,多用来做线性分割,假如所有的样本可以看做一个个点,如下图,有蓝色的点和绿色的点,传统的分类器 ...

  7. 我所写的CNN框架 VS caffe

    我所写的CNN框架 VS caffe 一个月前.自己模仿caffe实现了一个卷积神经网络的框架. 同样点 1无缝支持CPU和GPU模式,GPU模式使用cuda实现. 不同点 1我的CNN不依赖与不论什 ...

  8. ubuntu之路——day19.2 开源框架与迁移、CNN中的数据扩充

    开源框架与迁移 上面介绍了一些已经取得很好成绩的CNN框架,我们可以直接从GitHub上下载这些神经网络的结构和已经在ImageNet等数据集上训练好的权重超参数. 在应用于我们自己的数据时. 1.如 ...

  9. CNN基础框架简介

    卷积神经网络简介 卷积神经网络是多层感知机的变种,由生物学家休博尔和维瑟尔在早期关于猫视觉皮层的研究发展而来.视觉皮层的细胞存在一个复杂的构造,这些细胞对视觉输入空间的子区域非常敏感,我们称之为感受野 ...

随机推荐

  1. httpd添加新模块

    */ .hljs { display: block; overflow-x: auto; padding: 0.5em; color: #333; background: #f8f8f8; } .hl ...

  2. SpringMVC源码之Controller查找原理

    摘要 本文从源码层面简单讲解SpringMVC的处理器映射环节,也就是查找Controller详细过程. SpringMVC请求流程 Controller查找在上图中对应的步骤1至2的过程 Sprin ...

  3. dnion的remap.conf文件

    # # URL Remapping Config File # # Using remap.config allows you to accomplish two things: # # 1) Rew ...

  4. 计算机中RAM和ROM

    1.RAM(RamdomAccessMemory): 易挥发性随机存取存储器,高速存取,读写时间相等,且与地址无关,如计算机内存等. 2.ROM(Read Only Memory): 只读存储器.断电 ...

  5. 网页转图片--- html2canvas截图

    最近有个做在线名片(可保存图片至本地)的任务,特意研究了一下图片生成,也踩了几个坑.特此总结一下,顺便分享一下demo: 链接:https://pan.baidu.com/s/1o98UBJO 密码: ...

  6. C++——函数重载

    C++允许功能相近的函数在相同的作用域内以相同函数名声明,从而形成重载,方便使用,便于记忆. /*形参类型不同*/ int add(int x,int y); float add(float x,fl ...

  7. Java文件上传细讲

    什么是文件上传? 文件上传就是把用户的信息保存起来. 为什么需要文件上传? 在用户注册的时候,可能需要用户提交照片.那么这张照片就应该要进行保存. 上传组件(工具) 为什么我们要使用上传工具? 为啥我 ...

  8. 多线程编程学习笔记——编写一个异步的HTTP服务器和客户端

    接上文 多线程编程学习笔记——使用异步IO 二.   编写一个异步的HTTP服务器和客户端 本节展示了如何编写一个简单的异步HTTP服务器. 1.程序代码如下. using System; using ...

  9. z3 巧解CTF逆向题

    z3 巧解逆向题 题目下载链接:http://reversing.kr/download.php?n=7 这次实验的题目为Reversing.kr网站中的一道题目. 题目要求: ReversingKr ...

  10. Windows实用命令

     Windows实用命令   # 统计ESTABLISHED状态下的连接一共有多少个/c是统计行数,/i是忽略大小写 netstat -ano|find /i "established&qu ...