Caffe2的相关概念

接下来你可以学到更多Caffe2中主要的概念,这些概念对理解和开发Caffe2相当重要。

Blobs and Workspace,Tensors

Caffe2中,数据是用blobs储存的。Blob只是内存中的一个数据块。大多数Blobs包含一个张量(tensor),可以理解为多维矩阵,在Python中,他们被转换为numpy 矩阵。

Workspace 保存着所有的Blobs。下面的例子展示了如何向Workspace中传递Blobs和取出他们。Workspace在你开始使用他们时,才进行初始化。

  1. # Create random tensor of three dimensions
  2. x = np.random.rand(4, 3, 2)
  3. print(x)
  4. print(x.shape)
  5. workspace.FeedBlob("my_x", x)
  6. x2 = workspace.FetchBlob("my_x")
  7. print(x2)

Nets and Operators

Caffe2中最基本的对象是netnet可以说是一系列Operators的集合,每个Operator根据输入的blob输出一个或者多个blob

  下面我们将会创建一个超级简单的模型。他拥有如下部件:

  • 一个全连接层
  • 一个Sigmoid激活函数和一个Softmax函数
  • 一个交叉损失

      直接构建网络是很厌烦的,所以最好使用Python接口的模型助手来构建网络。我们只需简单的调用CNNModelHelper,他就会帮我们创建两个想联系的网络。
  • 一个用于初始化参数(ref.init_net
  • 一个用于实际训练(ref.init_net
  1. # Create the input data
  2. data = np.random.rand(16, 100).astype(np.float32)
  3. # Create labels for the data as integers [0, 9].
  4. label = (np.random.rand(16) * 10).astype(np.int32)
  5. workspace.FeedBlob("data", data)
  6. workspace.FeedBlob("label", label)
  7. # Create model using a model helper
  8. m = cnn.CNNModelHelper(name="my first net")
  9. fc_1 = m.FC("data", "fc1", dim_in=100, dim_out=10)
  10. pred = m.Sigmoid(fc_1, "pred")
  11. [softmax, loss] = m.SoftmaxWithLoss([pred, "label"], ["softmax", "loss"])

上面的代码中,我们首先在内存中创建了输入数据和标签,实际使用中,往往从database等载体中读入数据。可以看到输入数据和标签的第一维度是16,这是因为输入的最小batch最小是16。Caffe2中很多Operator都能直接通过CNNModelHelper来进行,并且能够一次处理一个batchCNNModelHelper’s Operator List中有更详细的解析。

  第二,我们通过一些操作创建了一个模型。比如FCSigmoidSoftmaxWithLoss 注意:这个时候,这些操作并没有真正执行,他们仅仅是对模型进行了定义。

  模型助手创建了两个网络:m.param_init_net,这个网络将仅仅被执行一次。他将会初始化参数blob,例如全连接层的权重。真正的训练是通过执行m.net来是现实的。这是自动发生的。

  网络的定义保存在一个protobuf结构体中。你可以很容易的通过调用net.proto来查看它。

  1. print(str(m.net.Proto()))

输出如下:

  1. name: "my first net"
  2. op {
  3. input: "data"
  4. input: "fc1_w"
  5. input: "fc1_b"
  6. output: "fc1"
  7. name: ""
  8. type: "FC"
  9. }
  10. op {
  11. input: "fc1"
  12. output: "pred"
  13. name: ""
  14. type: "Sigmoid"
  15. }
  16. op {
  17. input: "pred"
  18. input: "label"
  19. output: "softmax"
  20. output: "loss"
  21. name: ""
  22. type: "SoftmaxWithLoss"
  23. }
  24. external_input: "data"
  25. external_input: "fc1_w"
  26. external_input: "fc1_b"
  27. external_input: "label"

同时,你也可以查看参数初始化网络:

  1. print(str(m.param_init_net.Proto()))

这就是Caffe2的API:使用Python接口方便快速的构建网络并训练你的模型,Python接口将这些网络通过序列化的protobuf传递给C++接口,然后C++接口全力的执行。

Executing

现在我们可以开始训练我们的模型。

  首先,我们先跑一次参数初始化网络。

  1. workspace.RunNetOnce(m.param_init_net)

这个操作将会把param_init_netprotobuf传递给C++代码进行执行。

然后我们真正的创建网络

  1. workspace.CreateNet(m.net)

一旦创建好网络,我们就可以高效的跑起来:

  1. # Run 100 x 10 iterations 跑100*10次迭代
  2. for j in range(0, 100):
  3. data = np.random.rand(16, 100).astype(np.float32)
  4. label = (np.random.rand(16) * 10).astype(np.int32)
  5. workspace.FeedBlob("data", data)
  6. workspace.FeedBlob("label", label)
  7. workspace.RunNet(m.name, 10) # run for 10 times 跑十次

这里要注意的是我们怎样在RunNet()函数中使用网络的名字。并且在这里,由于网络已经在workspace中创建,所以我们不需要再传递网络的定义。执行完后,你可以查看存在输出blob中的结果。

  1. print(workspace.FetchBlob("softmax"))
  2. print(workspace.FetchBlob("loss"))

Backward pass

上面的网络中,仅仅包含了网络的前向传播,因此它是学习不到任何东西的。后向传播对每一个前向传播进行gradient operator。如果你想自己尝试这样的操作,那么你可以进行以下操作并检查结果。

RunNetOnce(),插入下面操作:

  1. m.AddGradientOperators([loss])

然后测试protobuf的输出:

  1. print(str(m.net.Proto()))

以上就是大体的使用教程

译者注

训练过程可以总结为以下步骤:

  1. # Create model using a model helper
  2. m = cnn.CNNModelHelper(name="my first net")
  3. fc_1 = m.FC("data", "fc1", dim_in=100, dim_out=10)
  4. pred = m.Sigmoid(fc_1, "pred")
  5. [softmax, loss] = m.SoftmaxWithLoss([pred, "label"], ["softmax", "loss"])
  6. m.AddGradientOperators([loss]) #注意这一行代码
  7. workspace.RunNetOnce(m.param_init_net)
  8. workspace.CreateNet(m.net)
  9. # Run 100 x 10 iterations
  10. for j in range(0, 100):
  11. data = np.random.rand(16, 100).astype(np.float32)
  12. label = (np.random.rand(16) * 10).astype(np.int32)
  13. workspace.FeedBlob("data", data)
  14. workspace.FeedBlob("label", label)
  15. workspace.RunNet(m.name, 10) # run for 10 times

结语:

转载请注明出处:http://www.jianshu.com/c/cf07b31bb5f2

Caffe2 手册(Intro Tutorial)[2]的更多相关文章

  1. Caffe2 Tutorials[0]

    本系列教程包括9个小节,对应Caffe2官网的前9个教程,第10个教程讲的是在安卓下用SqueezeNet进行物体检测,此处不再翻译.另外由于栏主不关注RNN和LSTM,所以栏主不对剩下两个教程翻译. ...

  2. linux下scrapy环境搭建

    最近使用scrapy做数据挖掘,使用scrapy定时抓取数据并存入MongoDB,本文记录环境搭建过程以作备忘 OS:ubuntu 14.04  python:2.7.6 scrapy:1.0.5 D ...

  3. Scrapy使用详细记录

    这几天,又用到了scrapy框架写爬虫,感觉忘得差不多了,虽然保存了书签,但有些东西,还是多写写才好啊 首先,官方而经典的的开发手册那是需要的: https://doc.scrapy.org/en/l ...

  4. 学python,怎么能不学习scrapy呢!

    摘要:本文讲述如何编写scrapy爬虫. 本文分享自华为云社区<学python,怎么能不学习scrapy呢,这篇博客带你学会它>,作者: 梦想橡皮擦 . 在正式编写爬虫案例前,先对 scr ...

  5. Scrapy开发指南

    一.Scrapy简介 Scrapy是一个为了爬取网站数据,提取结构性数据而编写的应用框架. 可以应用在包括数据挖掘,信息处理或存储历史数据等一系列的程序中. Scrapy基于事件驱动网络框架 Twis ...

  6. [转]python 常用类库!

    Python学习 On this page... (hide) 1. 基本安装 2. Python文档 2.1 推荐资源站点 2.2 其他参考资料 2.3 代码示例 3. 常用工具 3.1 Pytho ...

  7. Scrapy003-项目流程

    Scrapy003-项目流程 @(Spider)[POSTS] 前两篇文章我们了解到Scrapy的原理和安装的相关知识,这节就需要知道创建项目流程的小知识. 根据官方文档:http://scrapy- ...

  8. Python爬虫Scrapy框架入门(0)

    想学习爬虫,又想了解python语言,有个python高手推荐我看看scrapy. scrapy是一个python爬虫框架,据说很灵活,网上介绍该框架的信息很多,此处不再赘述.专心记录我自己遇到的问题 ...

  9. (转) Deep Learning Research Review Week 2: Reinforcement Learning

      Deep Learning Research Review Week 2: Reinforcement Learning 转载自: https://adeshpande3.github.io/ad ...

随机推荐

  1. boolean类型set、get方法

    今天在了解lombok的时候偶然看到一个问题,在bean中存在boolean类型的数据的时候,用eclipse工具自动生成的set.get方法存在的问题. 不管变量为isXXX还是XXX时,set.g ...

  2. Java中List集合的逆序排列

    Collections.reverse(list);  //实现List集合逆序排列

  3. esp8266(wifi)模块调试记录

    1.要注意usb转TTL接口上的晶振 如果晶振是12Mhz,可能就收不到反馈,因为12Mhz波特率会有误差.

  4. word中如何删除一张空白表格

    百度知道:https://baijiahao.baidu.com/s?id=1631677477148377412&wfr=spider&for=pc 当word中出现如下一张空白表格 ...

  5. Java 倒入文章显示前n个单词频率

    package com_1; import java.io.FileNotFoundException; import java.io.FileReader; import java.io.IOExc ...

  6. 2016 Google code jam 大赛

    二,RoundC import java.io.BufferedReader; import java.io.FileInputStream; import java.io.FileNotFoundE ...

  7. ZOJ - 3203 Light Bulb(三分)

    题意:灯离地面的高度为$H$,人的身高为$h$,灯离墙的距离为$D$,人站在不同位置,影子的长度不一样,求出影子的最长长度. 思路:设人离灯的距离为$x$,当人走到距离灯长度为$L$时,人在墙上的影子 ...

  8. java用JSONObject生成json

    Json在前后台传输中,是使用最多的一种数据类型.json生成的方法有很多,自己只是很皮毛的知道点,用的时候,难免会蒙.现在整理下 第一种: import net.sf.json.JSONArray; ...

  9. java redis 实现用户签到功能(很普通简单的签到功能)

    业务需求是用户每天只能签到一次,而且签到后用户增加积分,所以把用户每次签到时放到redis 缓存里面,然后每天凌晨时再清除缓存,大概简单思想是这样的 直接看代码吧如下 @Transactional @ ...

  10. django.db.utils.OperationalError: (2003, "Can't connect to MySQL server on ‘127.0.0.1’)

    报错信息如下: 检查发现原来是自己的sql没有启动 启动mysql后,