# 训练设置
# 使用GPU
caffe.set_device(gpu_id) # 若不设置,默认为0
caffe.set_mode_gpu()
# 使用CPU
caffe.set_mode_cpu() # 加载Solver,有两种常用方法
# 1. 无论模型中Slover类型是什么统一设置为SGD
solver = caffe.SGDSolver('/home/xxx/data/solver.prototxt')
# 2. 根据solver的prototxt中solver_type读取,默认为SGD
solver = caffe.get_solver('/home/xxx/data/solver.prototxt') # 训练模型
# 1.1 前向传播
solver.net.forward() # train net
solver.test_nets[0].forward() # test net (there can be more than one)
# 1.2 反向传播,计算梯度
solver.net.backward()
# 2. 进行一次前向传播一次反向传播并根据梯度更新参数
solver.step(1)
# 3. 根据solver文件中设置进行完整model训练
solver.solve()

如果想在训练过程中保存模型参数,调用

solver.net.save('mymodel.caffemodel')

caffe Python API 之Model训练的更多相关文章

  1. caffe Python API 之图片预处理

    # 设定图片的shape格式为网络data层格式 transformer = caffe.io.Transformer({'data': net.blobs['data'].data.shape}) ...

  2. caffe Python API 之BatchNormal

    net.bn = caffe.layers.BatchNorm( net.conv1, batch_norm_param=dict( moving_average_fraction=0.90, #滑动 ...

  3. caffe Python API 之可视化

    一.显示各层 # params显示:layer名,w,b for layer_name, param in net.params.items(): print layer_name + '\t' + ...

  4. caffe Python API 之中值转换

    # 编写一个函数,将二进制的均值转换为python的均值 def convert_mean(binMean,npyMean): blob = caffe.proto.caffe_pb2.BlobPro ...

  5. caffe Python API 之Solver定义

    from caffe.proto import caffe_pb2 s = caffe_pb2.SolverParameter() path='/home/xxx/data/' solver_file ...

  6. caffe Python API 之激活函数ReLU

    import sys import os sys.path.append("/projects/caffe-ssd/python") import caffe net = caff ...

  7. caffe Python API 之 数据输入层(Data,ImageData,HDF5Data)

    import sys sys.path.append('/projects/caffe-ssd/python') import caffe4 net = caffe.NetSpec() 一.Image ...

  8. caffe Python API 之上卷积层(Deconvolution)

    对于convolution: output = (input + 2 * p  - k)  / s + 1; 对于deconvolution: output = (input - 1) * s + k ...

  9. caffe Python API 之Inference

    #以SSD的检测测试为例 def detetion(image_dir,weight,deploy,resolution=300): caffe.set_mode_gpu() net = caffe. ...

随机推荐

  1. java中的error该不该捕获

    写java程序时,通常会被提示捕获异常,而又有一些异常是不需要强制捕获的,这是一个被说烂了的话题.像我一样从其他语言转过来的人确实有点迷惑,那我以我的理解重新解释一遍吧. 异常的基类是Exceptio ...

  2. bug:margin塌陷

    margin塌陷:两个嵌套的div,内部div的margin-top失效,内部对于外部的div并没有产生一个margin值,而是外部的div相对于上面的div产生了一个margin值. 弥补方法: 1 ...

  3. [三]SpringBoot 之 热部署

    如下配置 <plugin> <groupId>org.springframework.boot</groupId> <artifactId>spring ...

  4. 洛谷 [USACO09OPEN]工作调度

    题面 读完题,我们会发现有一个很重要的信息,每件物品代价相同,但价值不同.那么我们很容易想到,在满足限制的情况下,我们肯定会选择价值尽可能大的物品. 我们可否用背包来实现呢,答案是否定的,或者说我不会 ...

  5. java 调用 keytool 生成keystore 和 cer 证书

    keytool是一个Java数据证书的管理工具, keytool将密钥(key)和证书(certificates)存在一个称为keystore的文件中在keystore里, 包含两种数据:密钥实体(K ...

  6. 解题:BZOJ 4808 马

    题面 以前写过的题,翻出来学习网络流写二分图匹配,因为复杂度更优秀,$Dinic$是$O(sqrt(n)m)$哒~ 原点向左部点连流量为$1$的边,左部点向对应右部点连流量为$1$的边,右部点向汇点连 ...

  7. Git入门指南

    git学习资源: Pro Git(中文版) Learn Git in your browser for free with Try Git. Git常用命令 Reference 常用 Git 命令清单 ...

  8. tf.slice函数解析

    tf.slice函数解析 觉得有用的话,欢迎一起讨论相互学习~Follow Me tf.slice(input_, begin, size, name = None) 解释 : 这个函数的作用是从输入 ...

  9. centos6.5 mqtt安装

    CentOs 6.5 MQTT 安装部署 所需安装包: libwebsockets-v1.6-stable.tar.gz,mosquitto-1.4.8.tar.gz 1.安装依赖 # yum -y ...

  10. 最小生成树的边的概念问题!!! 最小生成树的计数 bzoj 1016

    1016: [JSOI2008]最小生成树计数 Time Limit: 1 Sec  Memory Limit: 162 MBSubmit: 5292  Solved: 2163[Submit][St ...