# 训练设置
# 使用GPU
caffe.set_device(gpu_id) # 若不设置,默认为0
caffe.set_mode_gpu()
# 使用CPU
caffe.set_mode_cpu() # 加载Solver,有两种常用方法
# 1. 无论模型中Slover类型是什么统一设置为SGD
solver = caffe.SGDSolver('/home/xxx/data/solver.prototxt')
# 2. 根据solver的prototxt中solver_type读取,默认为SGD
solver = caffe.get_solver('/home/xxx/data/solver.prototxt') # 训练模型
# 1.1 前向传播
solver.net.forward() # train net
solver.test_nets[0].forward() # test net (there can be more than one)
# 1.2 反向传播,计算梯度
solver.net.backward()
# 2. 进行一次前向传播一次反向传播并根据梯度更新参数
solver.step(1)
# 3. 根据solver文件中设置进行完整model训练
solver.solve()

如果想在训练过程中保存模型参数,调用

solver.net.save('mymodel.caffemodel')

caffe Python API 之Model训练的更多相关文章

  1. caffe Python API 之图片预处理

    # 设定图片的shape格式为网络data层格式 transformer = caffe.io.Transformer({'data': net.blobs['data'].data.shape}) ...

  2. caffe Python API 之BatchNormal

    net.bn = caffe.layers.BatchNorm( net.conv1, batch_norm_param=dict( moving_average_fraction=0.90, #滑动 ...

  3. caffe Python API 之可视化

    一.显示各层 # params显示:layer名,w,b for layer_name, param in net.params.items(): print layer_name + '\t' + ...

  4. caffe Python API 之中值转换

    # 编写一个函数,将二进制的均值转换为python的均值 def convert_mean(binMean,npyMean): blob = caffe.proto.caffe_pb2.BlobPro ...

  5. caffe Python API 之Solver定义

    from caffe.proto import caffe_pb2 s = caffe_pb2.SolverParameter() path='/home/xxx/data/' solver_file ...

  6. caffe Python API 之激活函数ReLU

    import sys import os sys.path.append("/projects/caffe-ssd/python") import caffe net = caff ...

  7. caffe Python API 之 数据输入层(Data,ImageData,HDF5Data)

    import sys sys.path.append('/projects/caffe-ssd/python') import caffe4 net = caffe.NetSpec() 一.Image ...

  8. caffe Python API 之上卷积层(Deconvolution)

    对于convolution: output = (input + 2 * p  - k)  / s + 1; 对于deconvolution: output = (input - 1) * s + k ...

  9. caffe Python API 之Inference

    #以SSD的检测测试为例 def detetion(image_dir,weight,deploy,resolution=300): caffe.set_mode_gpu() net = caffe. ...

随机推荐

  1. BZOJ3560 DZY Loves Math V(欧拉函数)

    对每个质因子分开计算再乘起来.使用类似生成函数的做法就很容易统计了. #include<iostream> #include<cstdio> #include<cmath ...

  2. Day 3 学习笔记

    Day 3 学习笔记 STL 模板库 一.结构体 结构体是把你所需要的一些自定义的类型(原类型.实例(:包括函数)的集合)都放到一个变量包里. 然后这个变量包与原先的类型差不多,可以开数组,是一种数据 ...

  3. 【BZOJ1562】【NOI2009】变换序列(二分图匹配)

    [BZOJ1562][NOI2009]变换序列 题面 BZOJ 洛谷 这题面写的是真的丑,还是先手动翻译成人话. 让你构造一个\(0..N-1\)的排列\(T\) 使得\(Dis(i,T_i)\)为给 ...

  4. After ZJOI2017 day2

    4.28早上6点左右就起了床,怀着紧张的心情,候到了7:45进考场 看到题,先0.5h看了看题意,yy一下,至少10+20+10. 首先是觉得T3可以搞一搞,先想到SA,很快就X掉了,思索一会儿,感觉 ...

  5. 洛谷 P2604 [ZJOI2010]网络扩容 解题报告

    P2604 [ZJOI2010]网络扩容 题目描述 给定一张有向图,每条边都有一个容量C和一个扩容费用W.这里扩容费用是指将容量扩大1所需的费用.求: 1. 在不扩容的情况下,1到N的最大流: 2. ...

  6. sysbench - 单组件式测试工具

    1 安装 > ./configure --with-mysql-includes=/usr/local/mysql/include --with-mysql-libs=/usr/local/my ...

  7. Web服务器进程连接数和请求连接数

    1.查看Web服务器进程连接数: netstat -antp | grep 80 | grep ESTABLISHED -c 2.查看Web服务器的并发请求数及其TCP连接状态: netstat -n ...

  8. SQL通用优化方案(where优化、索引优化、分页优化、事务优化、临时表优化)

    SQL通用优化方案:1. 使用参数化查询:防止SQL注入,预编译SQL命令提高效率2. 去掉不必要的查询和搜索字段:其实在项目的实际应用中,很多查询条件是可有可无的,能从源头上避免的多余功能尽量砍掉, ...

  9. 浴谷金秋线上集训营 T11738 伪神(树链剖分)

    先树链剖分,一棵子树的编号在数组上连续,一条链用树链剖分,把这些线段全部取出来,做差分,找到有多少点被>=t条线段覆盖即可. #include<iostream> #include& ...

  10. duilib CDateTimeUI 在Xp下的bug修复

    转自:http://my.oschina.net/u/343244/blog/370131 CDateTimeUI 的bug修复.修改CDateTimeWnd的HandleMessage方法 ? 1 ...