python caffe 在师兄的代码上修改成自己风格的代码
首先,感谢师兄的帮助。师兄的代码封装成类,流畅精美,容易调试。我的代码是堆积成的,被师兄嘲笑说写脚本。好吧!我的代码只有我懂,哈哈! 希望以后代码能写得工整点。现在还是让我先懂。这里,我做了一个简单的任务:0,1,2三个数字的分类。准确率:0.9806666666666667
(部分)代码分为:
1 train_net.py
#import some module
import time
import os
import numpy as np
import sys
import cv2
sys.path.append("/home/wang/Downloads/caffe-master/python")
import caffe
#from prepare_data import DataConfig
#from data_config import DataConfig #configure GPU mode
''' uncommend below line to use gpu '''
caffe.set_mode_gpu() # about dataset
##dataset = Dataset('/home/wang/Downloads/object/extract/')
##dataset = dataset.Split('train')
##data_config = DataConfig(dataset)
##data_config.SetBatchSize(256)
data_config='/home/wang/Downloads/caffe-master/examples/myFig_recognition/data/train/' #configure solve.prototxt
solver = caffe.SGDSolver('models/solver.prototxt') # load pretrain model
print('load pretrain model')
solver.net.copy_from('models/bvlc_reference_caffenet.caffemodel') solver.net.layers[0].SetDataConfig(data_config) for i in range(1, 10000):
# Make one SGD update
solver.step(5)
if i % 100 == 0:
solver.net.save('tmp.caffemodel')
''' TODO: test code '''
2 test_net.py
#import setup
import time
import os
import random
import sys
sys.path.append("/home/wang/Downloads/caffe-master/python")
import caffe
import cv2
import numpy as np
import random from utils import PrepareImage
#from dataset import Dataset
from test_data import test_data_pre test_num_once=10 ''' uncommend below line to use gpu '''
# caffe.set_mode_gpu() # dataset
#dataset = Dataset('/home/wang/Downloads/object/extract/')
#dataset = dataset.Split('test') # load net
net = caffe.Net('models/deploy.prototxt', caffe.TEST) # load train model
print('load pretrain model')
net.copy_from('tmp.caffemodel') #test all samples one by one
data_pre='/home/wang/Downloads/caffe-master/examples/myFig_recognition/data/test/'
#(imgPaths, gt_label) = dataset[int(random.random()*num_obj)]
(imgPaths, gt_label)=test_data_pre(data_pre)
num_img = len(imgPaths)
correct_num=0
for idx in range(num_img):
img = cv2.imread(imgPaths[idx])
img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
tmp_img = img.copy() # for display
img = PrepareImage(img, (227, 227))
net.blobs['data'].reshape(test_num_once, 3, 227, 227)
net.blobs['data'].data[...] = img
#net.blobs['data'].data[i,:,:,:] = img
net.forward()
score = net.blobs['cls_prob'].data
if score.argmax()==gt_label[idx]:
correct_num=correct_num+1
if idx%100==0:
print("Please wait some minutes...")
correct_rate=correct_num*1.0/num_img
print('The correct rate is :',correct_rate)
3 test_data.py
import os
import numpy as np
from random import randint
import cv2
from utils import PrepareImage,CatImage
#class data:
#path should be /home/
def test_data_pre(path):
img_list=[]
image_num=len(os.listdir(path+'/0'))+len(os.listdir(path+'/1'))+len(os.listdir(path+'/2'))
label = np.zeros(image_num, dtype=np.float32) i=0
for idf in range(3):
idf_str=str(idf)
path1=path+idf_str
tmp_path=os.listdir(path1)
for idi in range(len(tmp_path)):
img_path=path1+'/'+tmp_path[idi]
img_list.append(img_path)
label[i]=idf
i=i+1
return ( img_list,label)
4 pre_data.py
import os
import numpy as np
from random import randint
import cv2
from utils import PrepareImage,CatImage
#class data:
#path should be /home/
def prepare_data(path,batchsize):
#tmp_path=os.listdir(path)
img_list=[]
label = np.zeros(batchsize, dtype=np.float32)
for i in range(batchsize):
#randomly select one file
idf=randint(0,2)
idf_str=str(idf)
path1=path+idf_str
tmp_path=os.listdir(path1) #randomly select one image
idi=randint(0,len(tmp_path)-1)
#img = cv2.imread(imgPaths[idx])
img_path=path1+'/'+tmp_path[idi]
img=cv2.imread(img_path) img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
flip = randint(0, 1)>0
if flip > 0:
img = img[:, ::-1, :] # flip left to right img=PrepareImage(img, (227,227))
img_list.append(img)
label[i]=idf
imgData = CatImage(img_list)
return (imgData,label)
5 utils.py
import os
import cv2
import numpy as np def PrepareImage(im, size):
im = cv2.resize(im, (size[0], size[1]))
im = im.transpose(2, 0, 1)
im = im.astype(np.float32, copy=False)
return im def CatImage(im_list):
max_shape = np.array([im.shape for im in im_list]).max(axis=0)
blob = np.zeros((len(im_list), 3, max_shape[1], max_shape[2]), dtype=np.float32)
# set to mean value
blob[:, 0, :, :] = 102.9801
blob[:, 1, :, :] = 115.9465
blob[:, 2, :, :] = 122.7717
for i, im in enumerate(im_list):
blob[i, :, 0:im.shape[1], 0:im.shape[2]] = im
return blob
6 layer/data_layer.py
import caffe
import numpy as np #import data_config
#import prepare_data
from pre_data import prepare_data class DataLayer(caffe.Layer): def SetDataConfig(self, data_config):
self._data_config = data_config def GetDataConfig(self):
return self._data_config def setup(self, bottom, top):
# data blob
top[0].reshape(1, 3, 227, 227)
#top[0].reshape(1, 3, 34, 44)
# label type
top[1].reshape(1, 1) def reshape(self, bootom, top):
pass def forward(self, bottom, top):
#(imgs, label) = self._data_config.next()
path=self.GetDataConfig()
(imgs,label)=prepare_data(path,128)
(N, C, W, H) = imgs.shape
# image data
top[0].reshape(N, C, W, H)
top[0].data[...] = imgs
# object type label
top[1].reshape(N)
top[1].data[...] = label def backward(self, top, propagate_down, bottom):
pass
7 layer/__init__.py
import data_layer
还有一些caffe中经典的东西没放进来。
代码和数据:
python caffe 在师兄的代码上修改成自己风格的代码的更多相关文章
- 用Python给你的代码上个进度条吧 | 【代码也要面子的】
微信公众号:AI算法与图像处理如果你觉得对你有帮助,欢迎关注.转发以及点赞哦-( ̄▽ ̄-)~ 前言 最近在跑一些代码的时候,很烦...因为有时候不知道这段程序什么时候能执行完,现在执行哪里了,如果报错 ...
- Upsource——对已签入的代码进行分享、讨论和审查代码
Upsource 一.Upsource简介 Upsource ,这是一个专门为软件开发团队所设计的源代码协作工具.Upsource能够与多种版本控制工具进行集成,包括Git.Mercurial.Sub ...
- python之模块ftplib(实现ftp上传下载代码)
# -*- coding: utf-8 -*- #python 27 #xiaodeng #python之模块ftplib(实现ftp上传下载代码) #需求:实现ftp上传下载代码(不含错误处理) f ...
- 学习Git的一点心得以及如何把本地修改、删除的代码上传到github中
一:学习Github的资料如下:https://git.oschina.net/progit/ 这是一个学习Git的中文网站,如果诸位能够静下心来阅读,不要求阅读太多,只需要阅读前三章,就可以掌握Gi ...
- 基于Caffe的DeepID2实现(上)
小喵的唠叨话:小喵最近在做人脸识别的工作,打算将汤晓鸥前辈的DeepID,DeepID2等算法进行实验和复现.DeepID的方法最简单,而DeepID2的实现却略微复杂,并且互联网上也没有比较好的资源 ...
- 使用pycharm开发代码上传到GitLab和GitHub
使用pycharm开发代码上传到GitLab和GitHub 作者:尹正杰 版权声明:原创作品,谢绝转载!否则将追究法律责任. 我这里主要是针对局域网的自减的GitLab服务器,python开发工程师如 ...
- python 全栈开发,Day86(上传文件,上传头像,CBV,python读写Excel,虚拟环境virtualenv)
一.上传文件 上传一个图片 使用input type="file",来上传一个文件.注意:form表单必须添加属性enctype="multipart/form-data ...
- 使用git工具将本地电脑上的代码上传至GitHub
本文教你如果使用git工具将本地电脑上的代码上传至GitHub 1.安装git工具 安装git链接 2.使用git工具上传自己的代码到GitHub中 安装完git工具之后,我们会得到两个命令行工具,一 ...
- Dynamics AX 2012 R2 窗体系列 - 在窗体上修改字段时所触发的方法及其顺序
在这个系列里,Reinhard将和大家一起探索在AX的窗体上执行操作时,都会触发窗体.窗体数据源和表上的哪些方法,并且是以怎样的顺序触发的. 这次,我们来看看在窗体上修改或录入数据的情 ...
随机推荐
- [转载]Javassist 使用指南(一)
======================= 本文转载自简书,感谢原作者!. 原链接如下:https://www.jianshu.com/p/43424242846b =============== ...
- centos7配置安装redis
关闭防火墙:systemctl stop firewalld.service #停止firewallsystemctl disable firewalld.service #禁止firewall开机启 ...
- Why not inherit from List<T>?
问题: When planning out my programs, I often start with a chain of thought like so: A football team is ...
- 复习指南(Pascal版)
[第一层级 条件反射] 1.个十百千各数位的求法 q:=a div 1000 mod 10; b:=a div 100 mod 10; s:=a div 10 mod 10; g:=a mod 10; ...
- Asp.Net MVC 缓存设计
Asp.Net MVC 缓存: 1. 可以直接在Controller,Action上面定义输出缓存OutputCache,如下,第一次请求这个Index的时候,里面的代码会执行,并且结果会被缓存起来, ...
- Outlook.com 系列邮箱 POP3 及 IMAP 设置方法
支持 Exchange ActiveSync 的应用 有了 EAS,你可以立即获取电子邮件,以及在一个位置查看所有文件夹.日历和联系人. 如果你的电子邮件应用支持Exchange ActiveSync ...
- codeforces781C Underground Lab
本文版权归ljh2000和博客园共有,欢迎转载,但须保留此声明,并给出原文链接,谢谢合作. 本文作者:ljh2000 作者博客:http://www.cnblogs.com/ljh2000-jump/ ...
- 关于JNDI那点事
一.JNDI是什么? JNDI--Java 命名和目录接口(Java Naming and Directory Interface),是一组在Java应用中访问命名和目录服务的API. 二.JNDI好 ...
- java23种设计模式之二: 单例设计模式(6种写法)
目的:在某些业务场景中,我们需要某个类的实例对象的只能有一个,因此我们需要创建一些单例对象. 本文共有6种写法,仅供参考 1.饿汉式 优点: 在多线程情况下,该方法创建的单例是线程安全的(立即加载) ...
- JS中函数定义和函数表达式的区别
摘要: (function() {})();和(function(){}());的区别 Javascript中有2个语法都与function关键字有关,分别是: 函数定义:function Funct ...