一、说明

SIFT Flow 是一个标注的语义分割的数据集,有两个label,一个是语义分类(33类),另一个是场景标签(3类)。

Semantic and geometric segmentation classes for scenes.

Semantic: 0 is void and 1–33 are classes.

01 awning
02 balcony
03 bird
04 boat
05 bridge
06 building
07 bus
08 car
09 cow
10 crosswalk
11 desert
12 door
13 fence
14 field
15 grass
16 moon
17 mountain
18 person
19 plant
20 pole
21 river
22 road
23 rock
24 sand
25 sea
26 sidewalk
27 sign
28 sky
29 staircase
30 streetlight
31 sun
32 tree
33 window Geometric: -1 is void and 1–3 are classes. 01 sky
02 horizontal
03 vertical

二、模型训练

1、源码下载

git clone git@github.com:shelhamer/fcn.berkeleyvision.org.git

2、数据准备

下载标注好的SiftFlowDataset.zip数据集,地址:http://www.cs.unc.edu/~jtighe/Papers/ECCV10/siftflow/SiftFlowDataset.zip

将压缩包解压至data/sift-flow文件夹下。

3、代码修改

git clone git@github.com:litingpan/fcn.git

或从https://github.com/litingpan/fcn 下载,替换掉siftflow-fcn32s整个文件夹。

其中solve.py修改如下:

import caffe
import surgery, score import numpy as np
import os
import sys try:
import setproctitle
setproctitle.setproctitle(os.path.basename(os.getcwd()))
except:
pass # weights = '../ilsvrc-nets/vgg16-fcn.caffemodel'
vgg_weights = '../ilsvrc-nets/VGG_ILSVRC_16_layers.caffemodel'
vgg_proto = '../ilsvrc-nets/VGG_ILSVRC_16_layers_deploy.prototxt' # init
# caffe.set_device(int(sys.argv[1]))
caffe.set_device(0)
caffe.set_mode_gpu() # solver = caffe.SGDSolver('solver.prototxt')
# solver.net.copy_from(weights)
solver = caffe.SGDSolver('solver.prototxt')
vgg_net = caffe.Net(vgg_proto, vgg_weights, caffe.TRAIN)
surgery.transplant(solver.net, vgg_net)
del vgg_net # surgeries
interp_layers = [k for k in solver.net.params.keys() if 'up' in k]
surgery.interp(solver.net, interp_layers) # scoring
test = np.loadtxt('../data/sift-flow/test.txt', dtype=str) for _ in range(50):
solver.step(2000)
# N.B. metrics on the semantic labels are off b.c. of missing classes;
# score manually from the histogram instead for proper evaluation
score.seg_tests(solver, False, test, layer='score_sem', gt='sem')
score.seg_tests(solver, False, test, layer='score_geo', gt='geo')

4、下载预训练模型

Revisions · ILSVRC-2014 model (VGG team) with 16 weight layers  https://gist.github.com/ksimonyan/211839e770f7b538e2d8/revisions

同时下载VGG_ILSVRC_16_layers.caffemodel和VGG_ILSVRC_16_layers_deploy.prototxt放在ilsvrc-nets目录下

5、训练

python solve.py

训练完成后,在snapshot目录下train_iter_100000.caffemodel即为训练好的模型。

三、预测

1、模型准备

可以使用我们前面训练好的模型,如果不想自己训练,则可以直接下载训练好的模型http://dl.caffe.berkeleyvision.org/siftflow-fcn32s-heavy.caffemodel

2、deploy.prototxt

由test.prototxt修改过来的,主要修改了有三个地方,

(1)输入层

layer {
name: "input"
type: "Input"
top: "data"
input_param {
# These dimensions are purely for sake of example;
# see infer.py for how to reshape the net to the given input size.
shape { dim: 1 dim: 3 dim: 256 dim: 256 }
}
}

注意Input中,要与被测图片的尺寸一致。

(2)删掉了drop层

(3)删除了含有loss层相关层

3、infer.py

import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import sys
import caffe # the demo image is "2007_000129" from PASCAL VOC # load image, switch to BGR, subtract mean, and make dims C x H x W for Caffe
im = Image.open('coast_bea14.jpg')
in_ = np.array(im, dtype=np.float32)
in_ = in_[:,:,::-1]
in_ -= np.array((104.00698793,116.66876762,122.67891434))
in_ = in_.transpose((2,0,1)) # load net
net = caffe.Net('deploy.prototxt', 'snapshot/train_iter_100000.caffemodel', caffe.TEST)
# shape for input (data blob is N x C x H x W), set data
net.blobs['data'].reshape(1, *in_.shape)
net.blobs['data'].data[...] = in_
# run net and take argmax for prediction
net.forward()
sem_out = net.blobs['score_sem'].data[0].argmax(axis=0) # plt.imshow(out,cmap='gray');
plt.imshow(sem_out)
plt.axis('off')
plt.savefig('coast_bea14_sem_out.png')
sem_out_img = Image.fromarray(sem_out.astype('uint8')).convert('RGB')
sem_out_img.save('coast_bea14_sem_img_out.png') geo_out = net.blobs['score_geo'].data[0].argmax(axis=0)
plt.imshow(geo_out)
plt.axis('off')
plt.savefig('coast_bea14_geo_out.png')
geo_out_img = Image.fromarray(geo_out.astype('uint8')).convert('RGB')
geo_out_img.save('coast_bea14_geo_img_out.png')

其中,sem_out_img保存着语义分割的结果,geo_out_img保存场景标识的结果。

4、测试

python infer.py

Sift-flow中的图片都为256*256*3的彩色图片

images保存的是数据,semanticlabels保存的是语义分割标签,一共33类(而标注的数据会多一个无效类)。geolabels保存场景识别标签,共3类(而标注的数据会多一个无效类)。

所以是分别训练了两个网络,网络的前七层一样。

其中coast_bea14_sem_out.png为语义分割的结果, coast_bea14_geo_out.png为场景标识的结果,

原图                                                  语义分割                                                 场景标识

end

siftflow-fcn32s训练及预测的更多相关文章

  1. fcn训练及预测tgs数据集

    一.背景 kaggle上有这样一个题目,关于盐份预测的语义分割题目.TGS Salt Identification Challenge | Kaggle  https://www.kaggle.com ...

  2. 机器学习使用sklearn进行模型训练、预测和评价

    cross_val_score(model_name, x_samples, y_labels, cv=k) 作用:验证某个模型在某个训练集上的稳定性,输出k个预测精度. K折交叉验证(k-fold) ...

  3. 初识Sklearn-IrisData训练与预测

    笔记:机器学习入门---鸢尾花分类 Sklearn 本身就有很多数据库,可以用来练习. 以 Iris 的数据为例,这种花有四个属性,花瓣的长宽,茎的长宽,根据这些属性把花分为三类:山鸢尾花Setosa ...

  4. Spark技术在京东智能供应链预测的应用——按照业务进行划分,然后利用scikit learn进行单机训练并预测

    3.3 Spark在预测核心层的应用 我们使用Spark SQL和Spark RDD相结合的方式来编写程序,对于一般的数据处理,我们使用Spark的方式与其他无异,但是对于模型训练.预测这些需要调用算 ...

  5. ResNet网络的训练和预测

    ResNet网络的训练和预测 简介 Introduction 图像分类与CNN 图像分类 是指将图像信息中所反映的不同特征,把不同类别的目标区分开来的图像处理方法,是计算机视觉中其他任务,比如目标检测 ...

  6. Tensorflow训练和预测中的BN层的坑

    以前使用Caffe的时候没注意这个,现在使用预训练模型来动手做时遇到了.在slim中的自带模型中inception, resnet, mobilenet等都自带BN层,这个坑在<实战Google ...

  7. GUI:GUI的方式创建/训练/仿真/预测神经网络—Jason niu

    (1)导入数据:点击最左底部Import 按钮 (2)创建模型network_Jason_niu:点击底部的New按钮 (3)设置参数并训练:点击底部的Open按钮 (4)仿真预测: 大功告成!

  8. TensorFlow 1.4利用Keras+Estimator API进行训练和预测

    Tensorflow 1.4中,Keras作为作为核心模块可以直接通过tf.keas进行调用,但是考虑到keras对tfrecords文件进行操作比较麻烦,而将keras模型转成tensorflow中 ...

  9. 用C实现单隐层神经网络的训练和预测(手写BP算法)

    实验要求:•实现10以内的非负双精度浮点数加法,例如输入4.99和5.70,能够预测输出为10.69•使用Gprof测试代码热度 代码框架•随机初始化1000对数值在0~10之间的浮点数,保存在二维数 ...

随机推荐

  1. jquery tmpl生成导航

    引入<script src="jquery.tmpl.min.js"></script> html<ul class="nav" ...

  2. RobotFramework Selenium2 关键字

    *** Settings ***Library Selenium2Library *** Keywords ***Checkbox应该不被选择 [Arguments] ${locator} Check ...

  3. linux之时间设置

    date 显示与设置系统时间 %Y      year %m moth 月 %d day 日期 %H hour 小时 %M      minute   分钟 %S      sec  秒 +%F    ...

  4. python学习-迭代器,列表解析和列表生成式

    迭代器为类序列对象提供了一个类序列的接口.Python 的迭代无缝的支持序列对象,而且还允许程序猿迭代非序列类型,包括用户定义的对象. 迭代器是一个next()方法的对象,而不是通过索引计数.当需要下 ...

  5. c++利用类进行单链表的插入,删除,清空操作

    #if 1 #include <iostream> #include <stdlib.h> #include <time.h> #include <fstre ...

  6. SpringBoot中的ajax跨域问题

    在控制类加入注释@CrossOrigin(allowCredentials = "true",allowedHeaders = "*",origins = {& ...

  7. MS SQL 全局临时表的删除

    本来已经搜索到怎么删除了 如下: IF OBJECT_ID( 'tempdb..##TEMP_COPTD') IS NOT NULL Begin DROP TABLE ##TEMP_COPTD End ...

  8. 日志管理中获取浏览器、操作系统、IP等信息。。。

    今天在书写日志管理的模块的时候,遇到了一些问题,首先是日志的添加,就是在登录的时候记下他登录的名字以及登录的时间和登录的一些信息给存入到日志表中,这一下给蒙了,于是就查找资源,在这里我就简单地总结一下 ...

  9. win7 安装用mingw编译的Qt源码并连接postgresql

    下载Qt 1.下载qt-creator-windows-opensource-2.8.0,下载路径:http://download.qt.io/official_releases/qtcreator/ ...

  10. 从软件测试转型到C#上位机程序员

    一直在做软件测试的工作,天天与程序员不依不饶的争论细节的问题,没想到自己也有那么一天走上程序员的道路,由此开始,我的博客天天更新自己的学习状态,分享自己的心得. C#是微软公司发布的一种面向对象的.运 ...