caffe Python API 之Model训练
# 训练设置
# 使用GPU
caffe.set_device(gpu_id) # 若不设置,默认为0
caffe.set_mode_gpu()
# 使用CPU
caffe.set_mode_cpu() # 加载Solver,有两种常用方法
# 1. 无论模型中Slover类型是什么统一设置为SGD
solver = caffe.SGDSolver('/home/xxx/data/solver.prototxt')
# 2. 根据solver的prototxt中solver_type读取,默认为SGD
solver = caffe.get_solver('/home/xxx/data/solver.prototxt') # 训练模型
# 1.1 前向传播
solver.net.forward() # train net
solver.test_nets[0].forward() # test net (there can be more than one)
# 1.2 反向传播,计算梯度
solver.net.backward()
# 2. 进行一次前向传播一次反向传播并根据梯度更新参数
solver.step(1)
# 3. 根据solver文件中设置进行完整model训练
solver.solve()
如果想在训练过程中保存模型参数,调用
solver.net.save('mymodel.caffemodel')
caffe Python API 之Model训练的更多相关文章
- caffe Python API 之图片预处理
# 设定图片的shape格式为网络data层格式 transformer = caffe.io.Transformer({'data': net.blobs['data'].data.shape}) ...
- caffe Python API 之BatchNormal
net.bn = caffe.layers.BatchNorm( net.conv1, batch_norm_param=dict( moving_average_fraction=0.90, #滑动 ...
- caffe Python API 之可视化
一.显示各层 # params显示:layer名,w,b for layer_name, param in net.params.items(): print layer_name + '\t' + ...
- caffe Python API 之中值转换
# 编写一个函数,将二进制的均值转换为python的均值 def convert_mean(binMean,npyMean): blob = caffe.proto.caffe_pb2.BlobPro ...
- caffe Python API 之Solver定义
from caffe.proto import caffe_pb2 s = caffe_pb2.SolverParameter() path='/home/xxx/data/' solver_file ...
- caffe Python API 之激活函数ReLU
import sys import os sys.path.append("/projects/caffe-ssd/python") import caffe net = caff ...
- caffe Python API 之 数据输入层(Data,ImageData,HDF5Data)
import sys sys.path.append('/projects/caffe-ssd/python') import caffe4 net = caffe.NetSpec() 一.Image ...
- caffe Python API 之上卷积层(Deconvolution)
对于convolution: output = (input + 2 * p - k) / s + 1; 对于deconvolution: output = (input - 1) * s + k ...
- caffe Python API 之Inference
#以SSD的检测测试为例 def detetion(image_dir,weight,deploy,resolution=300): caffe.set_mode_gpu() net = caffe. ...
随机推荐
- java中的error该不该捕获
写java程序时,通常会被提示捕获异常,而又有一些异常是不需要强制捕获的,这是一个被说烂了的话题.像我一样从其他语言转过来的人确实有点迷惑,那我以我的理解重新解释一遍吧. 异常的基类是Exceptio ...
- bug:margin塌陷
margin塌陷:两个嵌套的div,内部div的margin-top失效,内部对于外部的div并没有产生一个margin值,而是外部的div相对于上面的div产生了一个margin值. 弥补方法: 1 ...
- [三]SpringBoot 之 热部署
如下配置 <plugin> <groupId>org.springframework.boot</groupId> <artifactId>spring ...
- 洛谷 [USACO09OPEN]工作调度
题面 读完题,我们会发现有一个很重要的信息,每件物品代价相同,但价值不同.那么我们很容易想到,在满足限制的情况下,我们肯定会选择价值尽可能大的物品. 我们可否用背包来实现呢,答案是否定的,或者说我不会 ...
- java 调用 keytool 生成keystore 和 cer 证书
keytool是一个Java数据证书的管理工具, keytool将密钥(key)和证书(certificates)存在一个称为keystore的文件中在keystore里, 包含两种数据:密钥实体(K ...
- 解题:BZOJ 4808 马
题面 以前写过的题,翻出来学习网络流写二分图匹配,因为复杂度更优秀,$Dinic$是$O(sqrt(n)m)$哒~ 原点向左部点连流量为$1$的边,左部点向对应右部点连流量为$1$的边,右部点向汇点连 ...
- Git入门指南
git学习资源: Pro Git(中文版) Learn Git in your browser for free with Try Git. Git常用命令 Reference 常用 Git 命令清单 ...
- tf.slice函数解析
tf.slice函数解析 觉得有用的话,欢迎一起讨论相互学习~Follow Me tf.slice(input_, begin, size, name = None) 解释 : 这个函数的作用是从输入 ...
- centos6.5 mqtt安装
CentOs 6.5 MQTT 安装部署 所需安装包: libwebsockets-v1.6-stable.tar.gz,mosquitto-1.4.8.tar.gz 1.安装依赖 # yum -y ...
- 最小生成树的边的概念问题!!! 最小生成树的计数 bzoj 1016
1016: [JSOI2008]最小生成树计数 Time Limit: 1 Sec Memory Limit: 162 MBSubmit: 5292 Solved: 2163[Submit][St ...