Caffe2 创建你的专属数据集(Create Your Own Dataset)[9]
这一节尝试把你的数据转换成caffe2能够使用的形式。这个教程使用Iris的数据集。你可以点击这里查看Ipython Notebook教程。
DB数据格式
Caffe2使用二进制的DB格式来保存数据。Caffe2 DB其实是键-值存储方式的一个美名而已。在键-值(key-value)存储方式里,键是随机生成的,所以batches是独立同分布的。而值(Value)则是真正的数据,他们包含着训练过程中真正用到的数据。所以,DB中保存的数据格式就像下面这样:
key1 value1 key2 value2 key3 value3 ...
在DB中,他把keys和values看成strings。你可以用TensorProtos protobuf来将你要保存的东西保存成DB数据结构。一个TensorProtos protobuf封装了Tensor(多维矩阵),和它的数据类型,形状信息。然后,你可以通过TensorProtosDBInput操作来载入数据到SGD训练过程中。
准备自己的数据
这里,我们向你展示如何创建自己的数据集。为此,我们将会使用UCI Iris数据集。这是一个非常受欢迎的经典的用于分类鸢尾花的数据集。它包含4个代表花的外形特征的实数。这个数据集包含3种鸢尾花。你可以从这里下载数据集。
%matplotlib inline
import urllib2 # 用于从网上下载数据集
import numpy as np
from matplotlib import pyplot
from StringIO import StringIO
from caffe2.python import core, utils, workspace
from caffe2.proto import caffe2_pb2
WARNING:root:This caffe2 python run does not have GPU support. Will run in CPU only mode.
WARNING:root:Debug message: No module named caffe2_pybind11_state_gpu
#如果你在Mac OS下使用homebrew,你可能会遇到一个错误: malloc_zone_unregister() 函数失败.这不是Caffe2的问题,而是因为 homebrew leveldb 的内存分配不兼容. 但这不影响使用。
f = urllib2.urlopen('https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data')
raw_data = f.read()
print('Raw data looks like this:')
print(raw_data[:100] + '...')
输出:
Raw data looks like this:
5.1,3.5,1.4,0.2,Iris-setosa
4.9,3.0,1.4,0.2,Iris-setosa
4.7,3.2,1.3,0.2,Iris-setosa
4.6,3.1,1.5,0.2,...
#将特征保存到一个特征矩阵
features = np.loadtxt(StringIO(raw_data), dtype=np.float32, delimiter=',', usecols=(0, 1, 2, 3))
#把label存到一个特征矩阵中
label_converter = lambda s : {'Iris-setosa':0, 'Iris-versicolor':1, 'Iris-virginica':2}[s]
labels = np.loadtxt(StringIO(raw_data), dtype=np.int, delimiter=',', usecols=(4,), converters={4: label_converter})
在我们开始训练之前,最好将数据集分成训练集和测试集。在这个例子中,让我们随机打乱数据,用前100个数据做训练集,剩余50个数据做测试。当然你也可以用更加复杂的方式,例如使用交叉校验的方式将数据集分成多个训练集和测试集。关于交叉校验的更多信息,请看这里。
random_index = np.random.permutation(150)
features = features[random_index]
labels = labels[random_index]
train_features = features[:100]
train_labels = labels[:100]
test_features = features[100:]
test_labels = labels[100:]
legend = ['rx', 'b+', 'go']
pyplot.title("Training data distribution, feature 0 and 1")
for i in range(3):
pyplot.plot(train_features[train_labels==i, 0], train_features[train_labels==i, 1], legend[i])
pyplot.figure()
pyplot.title("Testing data distribution, feature 0 and 1")
for i in range(3):
pyplot.plot(test_features[test_labels==i, 0], test_features[test_labels==i, 1], legend[i])
现在,把数据放进Caffe2的DB中去。在这个DB中,我们将会使用train_xxx
作为key,并对于每一个点使用一个TensorProtos对象去储存,一个TensorProtos包含两个tensor:一个是特征,一个是label。我们使用Caffe2的Python DB接口。
# 构建一个TensorProtos protobuf
feature_and_label = caffe2_pb2.TensorProtos()
feature_and_label.protos.extend([
utils.NumpyArrayToCaffe2Tensor(features[0]),
utils.NumpyArrayToCaffe2Tensor(labels[0])])
print('This is what the tensor proto looks like for a feature and its label:')
print(str(feature_and_label))
print('This is the compact string that gets written into the db:')
print(feature_and_label.SerializeToString())
This is what the tensor proto looks like for a feature and its label:
protos {
dims: 4
data_type: FLOAT
float_data: 5.40000009537
float_data: 3.0
float_data: 4.5
float_data: 1.5
}
protos {
data_type: INT32
int32_data: 1
}
This is the compact string that gets written into the db:
�̬@@@�@�?
"
现在真正写入DB中去
def write_db(db_type, db_name, features, labels):
db = core.C.create_db(db_type, db_name, core.C.Mode.write)
transaction = db.new_transaction()
for i in range(features.shape[0]):
feature_and_label = caffe2_pb2.TensorProtos()
feature_and_label.protos.extend([
utils.NumpyArrayToCaffe2Tensor(features[i]),
utils.NumpyArrayToCaffe2Tensor(labels[i])])
transaction.put(
'train_%03d'.format(i),
feature_and_label.SerializeToString())
# Close the transaction, and then close the db.
del transaction
del db
write_db("minidb", "iris_train.minidb", train_features, train_labels)
write_db("minidb", "iris_test.minidb", test_features, test_labels)
现在让我恩创建一个简单的网络,这个网络只包含一个简单的TensorProtosDBInput 操作,用来展示我们如何从创建好的DB中读入数据。
net_proto = core.Net("example_reader")
dbreader = net_proto.CreateDB([], "dbreader", db="iris_train.minidb", db_type="minidb")
net_proto.TensorProtosDBInput([dbreader], ["X", "Y"], batch_size=16)
print("The net looks like this:")
print(str(net_proto.Proto()))
The net looks like this:
name: "example_reader"
op {
output: "dbreader"
name: ""
type: "CreateDB"
arg {
name: "db_type"
s: "minidb"
}
arg {
name: "db"
s: "iris_train.minidb"
}
}
op {
input: "dbreader"
output: "X"
output: "Y"
name: ""
type: "TensorProtosDBInput"
arg {
name: "batch_size"
i: 16
}
}
创建网络
workspace.CreateNet(net_proto)
# 先跑一次,然后获取里面的数据
workspace.RunNet(net_proto.Proto().name)
print("The first batch of feature is:")
print(workspace.FetchBlob("X"))
print("The first batch of label is:")
print(workspace.FetchBlob("Y"))
# 再跑一次
workspace.RunNet(net_proto.Proto().name)
print("The second batch of feature is:")
print(workspace.FetchBlob("X"))
print("The second batch of label is:")
print(workspace.FetchBlob("Y"))
The first batch of feature is:
[[ 5.19999981 4.0999999 1.5 0.1 ]
[ 5.0999999 3.79999995 1.5 0.30000001]
[ 6.9000001 3.0999999 4.9000001 1.5 ]
[ 7.69999981 2.79999995 6.69999981 2. ]
[ 6.5999999 2.9000001 4.5999999 1.29999995]
[ 6.30000019 2.79999995 5.0999999 1.5 ]
[ 7.30000019 2.9000001 6.30000019 1.79999995]
[ 5.5999999 2.9000001 3.5999999 1.29999995]
[ 6.5 3. 5.19999981 2. ]
[ 5. 3.4000001 1.5 0.2 ]
[ 6.9000001 3.0999999 5.4000001 2.0999999 ]
[ 6. 3.4000001 4.5 1.60000002]
[ 5.4000001 3.4000001 1.70000005 0.2 ]
[ 6.30000019 2.70000005 4.9000001 1.79999995]
[ 5.19999981 2.70000005 3.9000001 1.39999998]
[ 6.19999981 2.9000001 4.30000019 1.29999995]]
The first batch of label is:
[0 0 1 2 1 2 2 1 2 0 2 1 0 2 1 1]
The second batch of feature is:
[[ 5.69999981 2.79999995 4.0999999 1.29999995]
[ 5.0999999 2.5 3. 1.10000002]
[ 4.4000001 2.9000001 1.39999998 0.2 ]
[ 7. 3.20000005 4.69999981 1.39999998]
[ 5.69999981 2.9000001 4.19999981 1.29999995]
[ 5. 3.5999999 1.39999998 0.2 ]
[ 5.19999981 3.5 1.5 0.2 ]
[ 6.69999981 3. 5.19999981 2.29999995]
[ 6.19999981 3.4000001 5.4000001 2.29999995]
[ 6.4000001 2.70000005 5.30000019 1.89999998]
[ 6.5 3.20000005 5.0999999 2. ]
[ 6.0999999 3. 4.9000001 1.79999995]
[ 5.4000001 3.4000001 1.5 0.40000001]
[ 4.9000001 3.0999999 1.5 0.1 ]
[ 5.5 3.5 1.29999995 0.2 ]
[ 6.69999981 3. 5. 1.70000005]]
The second batch of label is:
[1 1 0 1 1 0 0 2 2 2 2 2 0 0 0 1]
至此,本节教程结束。
转载请注明出处:http://www.jianshu.com/c/cf07b31bb5f2
Caffe2 创建你的专属数据集(Create Your Own Dataset)[9]的更多相关文章
- 用ArcMap在PostgreSQL中创建要素类需要执行”create enterprise geodatabase”吗
问:用Acmap在PostgreSQL中创建要素类需要执行"create enterprise geodatabase"吗? 关于这个问题,是在为新员工做postgresql培训后 ...
- 使用C#创建计划任务(How to create a Task Scheduler use C# )
本文主要讲解了如何使用C#来创建windows计划任务. 需求:在不定时间段运行多个后台程序(winfrom,wpf,console,等等)用于更新数据. 问题:为什么要使用计划任务,而不直接在程序 ...
- PHP MySQL 创建数据库和表 之 Create
创建数据库 CREATE DATABASE 语句用于在 MySQL 中创建数据库. 语法 CREATE DATABASE database_name 为了让 PHP 执行上面的语句,我们必须使用 my ...
- PIE创建带压缩的栅格数据集
这段时间我一直在研究如何用PIE创建带压缩的栅格数据集,由于我在比赛中使用的原始影像大小普遍都在300M以上,软件加载较慢,因此希望能对原始影像进行压缩,加快加载时间. 首先,该方法的关键是修改Dat ...
- Tensorflow创建和读取17flowers数据集
http://blog.csdn.net/sinat_16823063/article/details/53946549 Tensorflow创建和读取17flowers数据集 标签: tensorf ...
- 目标检测数据集The Object Detection Dataset
目标检测数据集The Object Detection Dataset 在目标检测领域,没有像MNIST或Fashion MNIST这样的小数据集.为了快速测试模型,我们将组装一个小数据集.首先,我们 ...
- Caffe2 手写字符识别(MNIST - Create a CNN from Scratch)[8]
本教程创建一个小的神经网络用于手写字符的识别.我们使用MNIST数据集进行训练和测试.这个数据集的训练集包含60000张来自500个人的手写字符的图像,测试集包含10000张独立于训练集的测试图像.你 ...
- ArcGIS 网络分析[8.4] 资料4 聚合——创建及打开网络数据集的类实现
这篇是对前三篇的总结,因为网络数据集涉及的"点"太多了,我只能挑重点来设置,大家明白框架后可以自行寻求帮助文档添加功能. 我以C#类的形式给出,这个类包含很多种方法,因为本人的C# ...
- MySQL中临时表的基本创建与使用教程(create temporary table )
当工作在非常大的表上时,你可能偶尔需要运行很多查询获得一个大量数据的小的子集,不是对整个表运行这些查询,而是让MySQL每次找出所需的少数记录,将记录选择到一个临时表可能更快些,然后在这些表运行查询. ...
随机推荐
- Atcoder Beginner Contest153F(模拟)
应该也可以用线段树/树状数组区间更新怪兽的生命值来做 #define HAVE_STRUCT_TIMESPEC #include<bits/stdc++.h> using namespac ...
- mybatis会自动把字段名中的下划线转为驼峰命名法?
先看一下转化的调用堆栈: 代码如下: 上面代码只是去掉了下划线,并没有首字母小写变大写的代码.再跟进findProperty方法可以找到获取驼峰结果的代码如下: 可以看出通过reflector.fin ...
- js相关--浅拷贝和深拷贝
1.js的数据类型 基本概述:js的数据类型分为两种,分别为基本数据类型和引用数据类型,它们俩的区别在于基本数据类型采用值传递,引用数据类型采用指针形式传递. 如下所示:引用类型通过简单的=进行复制, ...
- sqlmap 扫描注入漏洞
.检测是否存在漏洞: sqlmap -u .获取数据库信息: sqlmap -u --dbs .数据库表信息: sqlmap -u -D security --tables .表中字段信息 sqlma ...
- 关于websockets的压测工具
这是在workerman群中得到的信息,记录在此: loadrunner jemeter
- TD tree体验
在体验了学长们设计的app后,我颇有感触,我们也可以凭借自己的力量来开发一款软件,虽然它可能并不如市面上相同类型的那么完美,但它对我们的意义却是不一样的. 我是在下午的见面会上看到的这款软件,接待的学 ...
- 使用axios对安卓或者ios低版本兼容性处理
原因:不支持ES6,无法使用promise 解决办法: 1.安装 es6-promise cnpm install es6-promise --save-dev 2.引入 es6-promise im ...
- JAVA单例实现方式(常用)
JAVA单例实现方式(常用) public class Singleton { // Q1:为什么要使用volatile关键字? private volatile static Singleton u ...
- idea新建项目相关名词意义
新建项目中的对比 建完之后的项目对比 对比 新建中Artifact的名称对应maven中名字 新建中package的名字对应的是项目中src下package名字 新建中project name的名字对 ...
- 创业学习---今日头条创业过程分析---HHR计划
本文搜集和整理了今日头条创业的一些关键点的资料------by 春跃(本文的主要观点都是搜集整理,所以不得本人同意不得转载) 一,18年之前的今日头条创业时间表: 1,张一鸣参与创业的履历:酷讯,饭否 ...