NN入门,手把手教你用Numpy手撕NN(三)
NN入门,手把手教你用Numpy手撕NN(3)
这是一篇包含极少数学的CNN入门文章
上篇文章中简单介绍了NN的反向传播,并利用反向传播实现了一个简单的NN,在这篇文章中将介绍一下CNN。
CNN
CV(计算机视觉)作为AI的一大研究方向,越来越多的人选择了这个方向,其中使用的深度学习的方法基本以卷积神经网络(CNN)为基础。因此,这篇文章将介绍CNN的实现。
CNN与我们之前介绍的NN的相比,出现了卷积层(Convolution层)和池化层(Pooling层)。其网络架构大致如下图所示
卷积层
与全连接神经网络的对比
为什么在有存在全连接神经网络的情况下还会出现卷积神经网络呢?
这就得来看看全连接神经网络存在的问题了,全连接神经网络中存在的问题是数据的形状被“忽略了”。比如,输入的数据是图像时,图像通常是高、长、通道方向上的3为形状。
但是,全连接层输入时,需要将3维数据拉平为1维数据,因此,我们可能会丢失图像数据中存在有的空间信息(如空间上临近像素为相似的值、RGB的各个通道之间的关联性、相距较远像素之间的关联性等),这些信息都会被全连接层丢失。
而卷积层可以保持形状不变,将输入的数据以相同的维度输出,因此,可能可以正确理解图像等具有形状的数据。
卷积运算
卷积层进行的处理就是卷积运算,运算方式如下图所示
一般在计算中也会加上偏置
填充
在进行卷积层的处理之前,有时要向输入数据的周围填入固定的数据(如0),称为填充(padding)。如下图所示
向输入数据的周围填入0,将大小为(4, 4)的输入数据变成了(6, 6)的形状。
为什么要进行填充操作?
在对大小为(4, 4)的输入数据使用(3, 3)的滤波器时,输出的大小会变成(2, 2),如果反复进行多次卷积运算,在某个时刻输出大小就有可能变成1,导致无法再应用卷积运算。为了避免这种情况,就要使用填充,使得卷积运算可以在保持空间大小不变的情况下将数据传给下一层。
步幅
应用滤波器的位置间隔称为步幅(stride)。在上面的例子中,步幅都为1,如果将步幅设为2,则如下图所示
可以发现,增大步幅后,输出大小会变小,增大填充后,输出大小会变大。
假设输入大小为(H, W),滤波器大小为(FH, FW),输出大小为(OH, OW),填充为P,步幅为S,则输出大小可以表示为
OW=\frac{W+2P-FW}{S}+1
\]
三维数据卷积运算
上面卷积运算的例子都是二维的数据,但是,一般来说我们的图像数据都是三维的,除了高、宽之外,还需要处理通道方向的数据。如下图所示
其计算方式为每个通道处的数据与对应通道的卷积核相乘,最后将各个通道得到的结果相加,从而得到输出。
这里需要注意的是,一般情况下,卷积核的通道数需要与输入数据的通道数相同。(有时会使用1x1卷积核来对通道数进行降/升维操作)。可参考这篇文章
上面给出的例子输出的结果还是一个通道的,如果我们想要输出的结果在通道上也有多个输出,该怎么做呢?如下图所示
即使用多个卷积核
池化层
池化是缩小高、长方向上的空间的运算。比如下图所示的最大池化
除了上图所示的Max池化之外,还有Average池化。一般来说,池化的窗口大小会和步幅设定成相同的值。
im2col
从前面的例子来看,会发现,如果完全按照计算过程来写代码的话,要用上好几层for循环,这样的话不仅写起来麻烦,估计在运行的时候计算速度也很慢。这里将介绍im2col的方法。
im2col将输入数据展开以适合卷积核,如下图所示
对3维的输入数据应用im2col之后,数据转化维2维矩阵。
使用im2col展开输入数据后,之后就只需将卷积层的卷积核纵向展开为1列,并计算2个矩阵的乘积即可。这和全连接层的Affine层进行的处理基本相同。
代码实现
讲了这么多,这里将给出代码实现
import numpy as np
def im2col(input_data, filter_h, filter_w, stride=1, pad=0):
"""
input_data : 由(数据量,通道,高,长)的4维数组构成的输入数据
filter_h : 滤波器的高
filter_w : 滤波器的长
stride : 步幅
pad : 填充
"""
N, C, H, W = input_data.shape
out_h = (H + 2*pad - filter_h)//stride + 1
out_w = (W + 2*pad - filter_w)//stride + 1
img = np.pad(input_data, [(0,0), (0,0), (pad, pad), (pad, pad)], 'constant')
col = np.zeros((N, C, filter_h, filter_w, out_h, out_w))
for y in range(filter_h):
y_max = y + stride*out_h
for x in range(filter_w):
x_max = x + stride*out_w
col[:, :, y, x, :, :] = img[:, :, y:y_max:stride, x:x_max:stride]
col = col.transpose(0, 4, 5, 1, 2, 3).reshape(N*out_h*out_w, -1)
return col
def col2im(col, input_shape, filter_h, filter_w, stride=1, pad=0):
N, C, H, W = input_shape
out_h = (H + 2*pad - filter_h)//stride + 1
out_w = (W + 2*pad - filter_w)//stride + 1
col = col.reshape(N, out_h, out_w, C, filter_h, filter_w).transpose(0, 3, 4, 5, 1, 2)
img = np.zeros((N, C, H + 2*pad + stride - 1, W + 2*pad + stride - 1))
for y in range(filter_h):
y_max = y + stride*out_h
for x in range(filter_w):
x_max = x + stride*out_w
img[:, :, y:y_max:stride, x:x_max:stride] += col[:, :, y, x, :, :]
return img[:, :, pad:H + pad, pad:W + pad]
class Convolution:
def __init__(self, W, b, stride=1, pad=0):
self.W = W
self.b = b
self.stride = stride
self.pad = pad
self.x = None
self.col = None
self.col_W = None
self.dW = None
self.db = None
def forward(self, x):
# [N, C, H, W]
FN, C, FH, FW = self.W.shape
N, C, H, W = x.shape
out_h = int(1 + (H + 2 * self.pad - FH) / self.stride)
out_w = int(1 + (W + 2 * self.pad - FW) / self.stride)
col = im2col(x, FH, FW, self.stride, self.pad)
col_W = self.W.reshape(FN, -1).T # 滤波器展开
out = np.dot(col, col_W) + self.b
out = out.reshape(N, out_h, out_w, -1).transpose(0, 3, 1, 2)
self.x = x
self.col = col
self.col_W = col_W
return out
def backward(self, dout):
FN, C, FH, FW = self.W.shape
dout = dout.transpose(0, 2, 3, 1).reshape(-1, FN)
self.db = np.sum(dout, axis=0)
self.dW = np.dot(self.col.T, dout)
self.dW = self.dW.transpose(1, 0).reshape(FN, C, FH, FW)
dcol = np.dot(dout, self.col_W.T)
dx = col2im(dcol, self.x.shape, FH, FW, self.stride, self.pad)
return dx
class Pooling:
def __init__(self, pool_h, pool_w, stride=1, pad=0):
self.pool_h = pool_h
self.pool_w = pool_w
self.stride = stride
self.pad = pad
self.x = None
self.arg_max = None
def forward(self, x):
N, C, H, W = x.shape
out_h = int(1 + (H - self.pool_h) / self.stride)
out_w = int(1 + (W - self.pool_w) / self.stride)
col = im2col(x, self.pool_h, self.pool_w, self.stride, self.pad)
col = col.reshape(-1, self.pool_h*self.pool_w)
arg_max = np.argmax(col, axis=1)
out = np.max(col, axis=1)
out = out.reshape(N, out_h, out_w, -1).transpose(0, 3, 1, 2)
self.x = x
self.arg_max = arg_max
return out
def backward(self, dout):
dout = dout.transpose(0, 2, 3, 1)
pool_size = self.pool_h * self.pool_w
dmax = np.zeros((dout.size, pool_size))
dmax[np.arange(self.arg_max.size), self.arg_max.flatten()] = dout.flatten()
dmax = dmax.reshape(dout.shape + (pool_size,))
dcol = dmax.reshape(dmax.shape[0] * dmax.shape[1] * dmax.shape[2], -1)
dx = col2im(dcol, self.x.shape, self.pool_h, self.pool_w, self.stride, self.pad)
return dx
小节
这篇文章断断续续地写了好久,中间还顺便在学tensorflow 2.0 ,还是框架用的舒服 orz。。。这几天还是决定把这篇文章写完,坑挖了还是得填,numpy手撕NN系列也算是暂时完成了,RNN后面再考虑。。。这之后准备再补补一些学过的算法的总结以及前段时间看的一些论文的总结。
本文首发于我的知乎
NN入门,手把手教你用Numpy手撕NN(三)的更多相关文章
- NN入门,手把手教你用Numpy手撕NN(一)
前言 这是一篇包含极少数学推导的NN入门文章 大概从今年4月份起就想着学一学NN,但是无奈平时时间不多,而且空闲时间都拿去做比赛或是看动漫去了,所以一拖再拖,直到这8月份才正式开始NN的学习. 这篇文 ...
- NN入门,手把手教你用Numpy手撕NN(2)
这是一篇包含较少数学推导的NN入门文章 上篇文章中简单介绍了如何手撕一个NN,但其中仍有可以改进的地方,将在这篇文章中进行完善. 误差反向传播 之前的NN计算梯度是利用数值微分法,虽容易实现,但是计算 ...
- 《手把手教你》系列技巧篇(三十八)-java+ selenium自动化测试-日历时间控件-下篇(详解教程)
1.简介 理想很丰满现实很骨感,在应用selenium实现web自动化时,经常会遇到处理日期控件点击问题,手工很简单,可以一个个点击日期控件选择需要的日期,但自动化执行过程中,完全复制手工这样的操作就 ...
- 《手把手教你》系列技巧篇(三十)-java+ selenium自动化测试- Actions的相关操作下篇(详解教程)
1.简介 本文主要介绍两个在测试过程中可能会用到的功能:Actions类中的拖拽操作和Actions类中的划取字段操作.例如:需要在一堆log字符中随机划取一段文字,然后右键选择摘取功能. 2.拖拽操 ...
- 《手把手教你》系列技巧篇(三十一)-java+ selenium自动化测试- Actions的相关操作-番外篇(详解教程)
1.简介 上一篇中,宏哥说的宏哥在最后提到网站的反爬虫机制,那么宏哥在自己本地做一个网页,没有那个反爬虫的机制,谷歌浏览器是不是就可以验证成功了,宏哥就想验证一下自己想法,于是写了这一篇文章,另外也是 ...
- 《手把手教你》系列技巧篇(三十二)-java+ selenium自动化测试-select 下拉框(详解教程)
1.简介 在实际自动化测试过程中,我们也避免不了会遇到下拉选择的测试,因此宏哥在这里直接分享和介绍一下,希望小伙伴或者童鞋们在以后工作中遇到可以有所帮助. 2.select 下拉框 2.1Select ...
- 《手把手教你》系列技巧篇(三十三)-java+ selenium自动化测试-单选和多选按钮操作-上篇(详解教程)
1.简介 在实际自动化测试过程中,我们同样也避免不了会遇到单选和多选的测试,特别是调查问卷或者是答题系统中会经常碰到.因此宏哥在这里直接分享和介绍一下,希望小伙伴或者童鞋们在以后工作中遇到可以有所帮助 ...
- 《手把手教你》系列技巧篇(三十四)-java+ selenium自动化测试-单选和多选按钮操作-中篇(详解教程)
1.简介 今天这一篇宏哥主要是讲解一下,如何使用list容器来遍历单选按钮.大致两部分内容:一部分是宏哥在本地弄的一个小demo,另一部分,宏哥是利用JQueryUI网站里的单选按钮进行实战. 2.d ...
- 《手把手教你》系列技巧篇(三十五)-java+ selenium自动化测试-单选和多选按钮操作-下篇(详解教程)
1.简介 今天这一篇宏哥主要是讲解一下,如何使用list容器来遍历多选按钮.大致两部分内容:一部分是宏哥在本地弄的一个小demo,另一部分,宏哥是利用JQueryUI网站里的多选按钮进行实战. 2.d ...
随机推荐
- The usage of Markdown---文字强调:加粗/斜体/文本高亮/删除线/下划线/按键效果
更新时间:2019.09.14 1. 序言 有时候,我们需要对某些文字进行强调,例如粗体和斜体.而Markdown通常可以使用星号*或者下划线_进行文字强调. 2. 加粗 如果想要达到加粗的效果,可以 ...
- 使用SQLserver Management Studio连接VS2012自带数据库
下载 Microsoft® SQL Server® 2008 Management Studio Express http://www.microsoft.com/zh-CN/download/det ...
- day20作业
1.下面这段代码的输出结果将是什么?请解释. class Parent(object): x = 1 class Child1(Parent): pass class Child2(Parent): ...
- unity www下载导致内存占用增加问题
服务端或者数据库更改导致客户端更改,最合理的处理方法是客户端时刻检测版本号(可以通过实时检测版本号),如果实时刷新数据库的数据开销比较大,尤其是有图片元素时. 采用unity www类下载时,虽然结束 ...
- Udp 异步通信(三)
转自:https://blog.csdn.net/zhujunxxxxx/article/details/44258719 1)服务端 using System; using System.Colle ...
- Java8系列 (四) 静态方法和默认方法
静态方法和默认方法 我们可以在 Comparator 接口的源码中, 看到大量类似下面这样的方法声明 //default关键字修饰的默认方法 default Comparator<T> t ...
- fiddler的过滤
1.User Fiters启用 2.Action Action:Run Filterset now是否运行,Load Filterset加载,Save Filterset保存: 3.Hosts过滤 Z ...
- C和C++引用传递和数组传参引用
引用传递有两种传参方式,具体可参考文章 概括地讲,就是 *声明一个形参是指针,所以需要传递指针实参,对应的函数实现也应当遵循指针的语法.这种实现思路并不针对于C或者C++,因为它们都有指针,所以都可以 ...
- Web for pentester_writeup之LDAP attacks篇
Web for pentester_writeup之LDAP attacks篇 LDAP attacks(LDAP 攻击) LDAP是轻量目录访问协议,英文全称是Lightweight Directo ...
- 【翻译】Prometheus 2.12.0 新特性
Prometheus 2.12.0 现在(2019.08.17)已经发布,在上个月的 2.11.0 之后又进行了一些修正和改进. 在当前的 6 周发布周期中,每一个 Prometheus 版本都有比较 ...