一、说明

SIFT Flow 是一个标注的语义分割的数据集,有两个label,一个是语义分类(33类),另一个是场景标签(3类)。

Semantic and geometric segmentation classes for scenes.

Semantic: 0 is void and 1–33 are classes.

01 awning
02 balcony
03 bird
04 boat
05 bridge
06 building
07 bus
08 car
09 cow
10 crosswalk
11 desert
12 door
13 fence
14 field
15 grass
16 moon
17 mountain
18 person
19 plant
20 pole
21 river
22 road
23 rock
24 sand
25 sea
26 sidewalk
27 sign
28 sky
29 staircase
30 streetlight
31 sun
32 tree
33 window Geometric: -1 is void and 1–3 are classes. 01 sky
02 horizontal
03 vertical

二、模型训练

1、源码下载

git clone git@github.com:shelhamer/fcn.berkeleyvision.org.git

2、数据准备

下载标注好的SiftFlowDataset.zip数据集,地址:http://www.cs.unc.edu/~jtighe/Papers/ECCV10/siftflow/SiftFlowDataset.zip

将压缩包解压至data/sift-flow文件夹下。

3、代码修改

git clone git@github.com:litingpan/fcn.git

或从https://github.com/litingpan/fcn 下载,替换掉siftflow-fcn32s整个文件夹。

其中solve.py修改如下:

import caffe
import surgery, score import numpy as np
import os
import sys try:
import setproctitle
setproctitle.setproctitle(os.path.basename(os.getcwd()))
except:
pass # weights = '../ilsvrc-nets/vgg16-fcn.caffemodel'
vgg_weights = '../ilsvrc-nets/VGG_ILSVRC_16_layers.caffemodel'
vgg_proto = '../ilsvrc-nets/VGG_ILSVRC_16_layers_deploy.prototxt' # init
# caffe.set_device(int(sys.argv[1]))
caffe.set_device(0)
caffe.set_mode_gpu() # solver = caffe.SGDSolver('solver.prototxt')
# solver.net.copy_from(weights)
solver = caffe.SGDSolver('solver.prototxt')
vgg_net = caffe.Net(vgg_proto, vgg_weights, caffe.TRAIN)
surgery.transplant(solver.net, vgg_net)
del vgg_net # surgeries
interp_layers = [k for k in solver.net.params.keys() if 'up' in k]
surgery.interp(solver.net, interp_layers) # scoring
test = np.loadtxt('../data/sift-flow/test.txt', dtype=str) for _ in range(50):
solver.step(2000)
# N.B. metrics on the semantic labels are off b.c. of missing classes;
# score manually from the histogram instead for proper evaluation
score.seg_tests(solver, False, test, layer='score_sem', gt='sem')
score.seg_tests(solver, False, test, layer='score_geo', gt='geo')

4、下载预训练模型

Revisions · ILSVRC-2014 model (VGG team) with 16 weight layers  https://gist.github.com/ksimonyan/211839e770f7b538e2d8/revisions

同时下载VGG_ILSVRC_16_layers.caffemodel和VGG_ILSVRC_16_layers_deploy.prototxt放在ilsvrc-nets目录下

5、训练

python solve.py

训练完成后,在snapshot目录下train_iter_100000.caffemodel即为训练好的模型。

三、预测

1、模型准备

可以使用我们前面训练好的模型,如果不想自己训练,则可以直接下载训练好的模型http://dl.caffe.berkeleyvision.org/siftflow-fcn32s-heavy.caffemodel

2、deploy.prototxt

由test.prototxt修改过来的,主要修改了有三个地方,

(1)输入层

layer {
name: "input"
type: "Input"
top: "data"
input_param {
# These dimensions are purely for sake of example;
# see infer.py for how to reshape the net to the given input size.
shape { dim: 1 dim: 3 dim: 256 dim: 256 }
}
}

注意Input中,要与被测图片的尺寸一致。

(2)删掉了drop层

(3)删除了含有loss层相关层

3、infer.py

import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import sys
import caffe # the demo image is "2007_000129" from PASCAL VOC # load image, switch to BGR, subtract mean, and make dims C x H x W for Caffe
im = Image.open('coast_bea14.jpg')
in_ = np.array(im, dtype=np.float32)
in_ = in_[:,:,::-1]
in_ -= np.array((104.00698793,116.66876762,122.67891434))
in_ = in_.transpose((2,0,1)) # load net
net = caffe.Net('deploy.prototxt', 'snapshot/train_iter_100000.caffemodel', caffe.TEST)
# shape for input (data blob is N x C x H x W), set data
net.blobs['data'].reshape(1, *in_.shape)
net.blobs['data'].data[...] = in_
# run net and take argmax for prediction
net.forward()
sem_out = net.blobs['score_sem'].data[0].argmax(axis=0) # plt.imshow(out,cmap='gray');
plt.imshow(sem_out)
plt.axis('off')
plt.savefig('coast_bea14_sem_out.png')
sem_out_img = Image.fromarray(sem_out.astype('uint8')).convert('RGB')
sem_out_img.save('coast_bea14_sem_img_out.png') geo_out = net.blobs['score_geo'].data[0].argmax(axis=0)
plt.imshow(geo_out)
plt.axis('off')
plt.savefig('coast_bea14_geo_out.png')
geo_out_img = Image.fromarray(geo_out.astype('uint8')).convert('RGB')
geo_out_img.save('coast_bea14_geo_img_out.png')

其中,sem_out_img保存着语义分割的结果,geo_out_img保存场景标识的结果。

4、测试

python infer.py

Sift-flow中的图片都为256*256*3的彩色图片

images保存的是数据,semanticlabels保存的是语义分割标签,一共33类(而标注的数据会多一个无效类)。geolabels保存场景识别标签,共3类(而标注的数据会多一个无效类)。

所以是分别训练了两个网络,网络的前七层一样。

其中coast_bea14_sem_out.png为语义分割的结果, coast_bea14_geo_out.png为场景标识的结果,

原图                                                  语义分割                                                 场景标识

end

siftflow-fcn32s训练及预测的更多相关文章

  1. fcn训练及预测tgs数据集

    一.背景 kaggle上有这样一个题目,关于盐份预测的语义分割题目.TGS Salt Identification Challenge | Kaggle  https://www.kaggle.com ...

  2. 机器学习使用sklearn进行模型训练、预测和评价

    cross_val_score(model_name, x_samples, y_labels, cv=k) 作用:验证某个模型在某个训练集上的稳定性,输出k个预测精度. K折交叉验证(k-fold) ...

  3. 初识Sklearn-IrisData训练与预测

    笔记:机器学习入门---鸢尾花分类 Sklearn 本身就有很多数据库,可以用来练习. 以 Iris 的数据为例,这种花有四个属性,花瓣的长宽,茎的长宽,根据这些属性把花分为三类:山鸢尾花Setosa ...

  4. Spark技术在京东智能供应链预测的应用——按照业务进行划分,然后利用scikit learn进行单机训练并预测

    3.3 Spark在预测核心层的应用 我们使用Spark SQL和Spark RDD相结合的方式来编写程序,对于一般的数据处理,我们使用Spark的方式与其他无异,但是对于模型训练.预测这些需要调用算 ...

  5. ResNet网络的训练和预测

    ResNet网络的训练和预测 简介 Introduction 图像分类与CNN 图像分类 是指将图像信息中所反映的不同特征,把不同类别的目标区分开来的图像处理方法,是计算机视觉中其他任务,比如目标检测 ...

  6. Tensorflow训练和预测中的BN层的坑

    以前使用Caffe的时候没注意这个,现在使用预训练模型来动手做时遇到了.在slim中的自带模型中inception, resnet, mobilenet等都自带BN层,这个坑在<实战Google ...

  7. GUI:GUI的方式创建/训练/仿真/预测神经网络—Jason niu

    (1)导入数据:点击最左底部Import 按钮 (2)创建模型network_Jason_niu:点击底部的New按钮 (3)设置参数并训练:点击底部的Open按钮 (4)仿真预测: 大功告成!

  8. TensorFlow 1.4利用Keras+Estimator API进行训练和预测

    Tensorflow 1.4中,Keras作为作为核心模块可以直接通过tf.keas进行调用,但是考虑到keras对tfrecords文件进行操作比较麻烦,而将keras模型转成tensorflow中 ...

  9. 用C实现单隐层神经网络的训练和预测(手写BP算法)

    实验要求:•实现10以内的非负双精度浮点数加法,例如输入4.99和5.70,能够预测输出为10.69•使用Gprof测试代码热度 代码框架•随机初始化1000对数值在0~10之间的浮点数,保存在二维数 ...

随机推荐

  1. hadoop.docker.up.problems: Too many levels of symbolic links

    #root@c7hp:~ excp c78 "zkServer.sh start"[1] 11:49:44 [FAILURE] c78 Exited with error code ...

  2. makefile笔记1 - 初识makefile

    前情提要 上一篇<编译入门>讲了变成的基本问题.如果源文件只有一个,就如之前的例子,那么用gcc命令直接编译就可以了.但是很多实际的工程用到的源文件都是相当多的,这时候用命令一个个编译是很 ...

  3. 微信小程序调用快递物流查询API的实现方法

    一. 创建index.wxml.index.wxss.index.js 附上代码: <view class='container'> <input class='info' plac ...

  4. django websocket

    1.dwebsocket 2.等框架都是错误的 3.  django/channels 才是正确姿势 555 4. pip install -U channels 完成后,您应该添加channels到 ...

  5. ios jenkins从0快速配置

    1,安装:brew install jenkins2,命令行里:jenkins 回车,第一次会生成密码和保存密码的路径如:/Users/uname/.jenkins/secrets/initialAd ...

  6. github项目

    一.github项目地址: https://github.com/fairy1231/gitLearning/tree/master 二.github的重要性: Git 是一个快速.可扩展的分布式版本 ...

  7. Python高阶函数和匿名函数

    高阶函数:就是把函数当成参数传递的一种函数:例如 注解: 1.调用add函数,分别执行abs(-8)和abs(11),分别计算出他们的值 2.最后在做和运算 map()函数 python内置的一个高阶 ...

  8. Django models 的字段类型

    1.models.AutoField   ---自增列 = int(11)    如果没有的话,默认会生成一个名称为 id 的列,如果要显示的自定义一个自增列,必须将给列设置为主键 primary_k ...

  9. es6学习笔记-Proxy、Reflect、Promise

    Proxy Proxy 用于修改某些操作的默认行为,等同于在语言层面做出修改,所以属于一种“元编程”(meta programming),即对编程语言进行编程. Proxy 可以理解成,在目标对象之前 ...

  10. string 迭代器

    #include <iostream>#include <string>#include<algorithm>#define m 10000000using nam ...