『TensorFlow』正则化添加方法整理
一、基础正则化函数
tf.contrib.layers.l1_regularizer(scale, scope=None)
返回一个用来执行L1
正则化的函数,函数的签名是func(weights)
.
参数:
- scale: 正则项的系数.
- scope: 可选的
scope name
tf.contrib.layers.l2_regularizer(scale, scope=None)
先看看tf.contrib.layers.l2_regularizer(weight_decay)都执行了什么:
- import tensorflow as tf
- sess=tf.Session()
- weight_decay=0.1
- tmp=tf.constant([0,1,2,3],dtype=tf.float32)
- """
- l2_reg=tf.contrib.layers.l2_regularizer(weight_decay)
- a=tf.get_variable("I_am_a",regularizer=l2_reg,initializer=tmp)
- """
- #**上面代码的等价代码
- a=tf.get_variable("I_am_a",initializer=tmp)
- a2=tf.reduce_sum(a*a)*weight_decay/2;
- a3=tf.get_variable(a.name.split(":")[0]+"/Regularizer/l2_regularizer",initializer=a2)
- tf.add_to_collection(tf.GraphKeys.REGULARIZATION_LOSSES,a2)
- #**
- sess.run(tf.global_variables_initializer())
- keys = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
- for key in keys:
- print("%s : %s" %(key.name,sess.run(key)))
- import tensorflow as tf
- sess=tf.Session()
- weight_decay=0.1 #(1)定义weight_decay
- l2_reg=tf.contrib.layers.l2_regularizer(weight_decay) #(2)定义l2_regularizer()
- tmp=tf.constant([0,1,2,3],dtype=tf.float32)
- a=tf.get_variable("I_am_a",regularizer=l2_reg,initializer=tmp) #(3)创建variable,l2_regularizer复制给regularizer参数。
- #目测REXXX_LOSSES集合
- #regularizer定义会将a加入REGULARIZATION_LOSSES集合
- print("Global Set:")
- keys = tf.get_collection("variables")
- for key in keys:
- print(key.name)
- print("Regular Set:")
- keys = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
- for key in keys:
- print(key.name)
- print("--------------------")
- sess.run(tf.global_variables_initializer())
- print(sess.run(a))
- reg_set=tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES) #(4)则REGULARIAZTION_LOSSES集合会包含所有被weight_decay后的参数和,将其相加
- l2_loss=tf.add_n(reg_set)
- print("loss=%s" %(sess.run(l2_loss)))
- """
- 此处输出0.7,即:
- weight_decay*sigmal(w*2)/2=0.1*(0*0+1*1+2*2+3*3)/2=0.7
- 其实代码自己写也很方便,用API看着比较正规。
- 在网络模型中,直接将l2_loss加入loss就好了。(loss变大,执行train自然会decay)
- """
二、添加正则化方法
a、原始办法
正则化常用到集合,下面是最原始的添加正则办法(直接在变量声明后将之添加进'losses'集合或tf.GraphKeys.LOESSES也行):
- import tensorflow as tf
- import numpy as np
- def get_weights(shape, lambd):
- var = tf.Variable(tf.random_normal(shape), dtype=tf.float32)
- tf.add_to_collection('losses', tf.contrib.layers.l2_regularizer(lambd)(var))
- return var
- x = tf.placeholder(tf.float32, shape=(None, 2))
- y_ = tf.placeholder(tf.float32, shape=(None, 1))
- batch_size = 8
- layer_dimension = [2, 10, 10, 10, 1]
- n_layers = len(layer_dimension)
- cur_lay = x
- in_dimension = layer_dimension[0]
- for i in range(1, n_layers):
- out_dimension = layer_dimension[i]
- weights = get_weights([in_dimension, out_dimension], 0.001)
- bias = tf.Variable(tf.constant(0.1, shape=[out_dimension]))
- cur_lay = tf.nn.relu(tf.matmul(cur_lay, weights)+bias)
- in_dimension = layer_dimension[i]
- mess_loss = tf.reduce_mean(tf.square(y_-cur_lay))
- tf.add_to_collection('losses', mess_loss)
- loss = tf.add_n(tf.get_collection('losses'))
b、tf.contrib.layers.apply_regularization(regularizer, weights_list=None)
先看参数
- regularizer:就是我们上一步创建的正则化方法
- weights_list: 想要执行正则化方法的参数列表,如果为
None
的话,就取GraphKeys.WEIGHTS
中的weights
.
函数返回一个标量Tensor
,同时,这个标量Tensor
也会保存到GraphKeys.REGULARIZATION_LOSSES
中.这个Tensor
保存了计算正则项损失的方法.
tensorflow
中的Tensor
是保存了计算这个值的路径(方法),当我们run的时候,tensorflow
后端就通过路径计算出Tensor
对应的值
现在,我们只需将这个正则项损失加到我们的损失函数上就可以了.
如果是自己手动定义
weight
的话,需要手动将weight
保存到GraphKeys.WEIGHTS
中,但是如果使用layer
的话,就不用这么麻烦了,别人已经帮你考虑好了.(最好自己验证一下tf.GraphKeys.WEIGHTS
中是否包含了所有的weights
,防止被坑)
c、使用slim
使用slim会简单很多:
- with slim.arg_scope([slim.conv2d, slim.fully_connected],
- activation_fn=tf.nn.relu,
- weights_regularizer=slim.l2_regularizer(weight_decay)):
- pass
此时添加集合为tf.GraphKeys.REGULARIZATION_LOSSES。
『TensorFlow』正则化添加方法整理的更多相关文章
- 『TensorFlow』专题汇总
TensorFlow:官方文档 TensorFlow:项目地址 本篇列出文章对于全零新手不太合适,可以尝试TensorFlow入门系列博客,搭配其他资料进行学习. Keras使用tf.Session训 ...
- 『TensorFlow』滑动平均
滑动平均会为目标变量维护一个影子变量,影子变量不影响原变量的更新维护,但是在测试或者实际预测过程中(非训练时),使用影子变量代替原变量. 1.滑动平均求解对象初始化 ema = tf.train.Ex ...
- 『TensorFlow』模型保存和载入方法汇总
『TensorFlow』第七弹_保存&载入会话_霸王回马 一.TensorFlow常规模型加载方法 保存模型 tf.train.Saver()类,.save(sess, ckpt文件目录)方法 ...
- 『TensorFlow』读书笔记_进阶卷积神经网络_分类cifar10_上
完整项目见:Github 完整项目中最终使用了ResNet进行分类,而卷积版本较本篇中结构为了提升训练效果也略有改动 本节主要介绍进阶的卷积神经网络设计相关,数据读入以及增强在下一节再与介绍 网络相关 ...
- 『TensorFlow』TFR数据预处理探究以及框架搭建
一.TFRecord文件书写效率对比(单线程和多线程对比) 1.准备工作 # Author : Hellcat # Time : 18-1-15 ''' import os os.environ[&q ...
- 『TensorFlow』SSD源码学习_其一:论文及开源项目文档介绍
一.论文介绍 读论文系列:Object Detection ECCV2016 SSD 一句话概括:SSD就是关于类别的多尺度RPN网络 基本思路: 基础网络后接多层feature map 多层feat ...
- 『TensorFlow』流程控制
『PyTorch』第六弹_最小二乘法对比PyTorch和TensorFlow TensorFlow 控制流程操作 TensorFlow 提供了几个操作和类,您可以使用它们来控制操作的执行并向图中添加条 ...
- 『TensorFlow』读书笔记_降噪自编码器
『TensorFlow』降噪自编码器设计 之前学习过的代码,又敲了一遍,新的收获也还是有的,因为这次注释写的比较详尽,所以再次记录一下,具体的相关知识查阅之前写的文章即可(见上面链接). # Aut ...
- 『TensorFlow』梯度优化相关
tf.trainable_variables可以得到整个模型中所有trainable=True的Variable,也是自由处理梯度的基础 基础梯度操作方法: tf.gradients 用来计算导数.该 ...
随机推荐
- Intellij IDEA注册激活破解
1.2017年适用(2016.3.5到2017.2.4版均生效) 安装IntelliJ IDEA 最新版 启动IntelliJ IDEA 输入 license时,选择输入 [License serve ...
- Jmeter学习之-获取登录的oken值(1)
ps: 这里只着重讲述如何实时获取其他接口返回的值,作为此次接口的参数传递,添加接口请求的相关不再详述,可查看上一篇文章 为了方便管理,此处将:登录接口单独放在一个线程组下面,需要使用登录接口返回的t ...
- java框架之SpringBoot(16)-分布式及整合Dubbo
前言 分布式应用 在分布式系统中,国内常用 Zookeeper + Dubbo 组合,而 SpringBoot 推荐使用 Spring 提供的分布式一站式解决方案 Spring + SpringBoo ...
- 用git如何把单个文件回退到某一版本
暂定此文件为a.jsp 1.首先到a.jsp所在目录: 通过 git log a.jsp 查看a.jsp的更改记录 2.找到想要回退的版本号:例如 fcd2093 通过 git reset fcd ...
- 【UML】NO.47.EBook.5.UML.1.007-【UML 大战需求分析】- 部署图(Deployment Diagram)
1.0.0 Summary Tittle:[UML]NO.47.EBook.1.UML.1.007-[UML 大战需求分析]- 部署图(Deployment Diagram) Style:Design ...
- oracle 新建用户后赋予的权限语句
grant create session,resource to itsys; grant create table to itsys;grant resource to itsys;grant cr ...
- CentOS 7 开机延迟解决办法
遇到这种情况 , 开机延迟 , 可以用下面的办法来查看 , 寻找到问题的源头 , 来看看到的是怎么回事 [root@DaMoWang ~]# dmesg |grep udev #显示系统的启动信息 ...
- 【转】jira迁移数据
jira迁移数据有两种方式 方式一: jira系统自带的备份恢复操作 最简单的,但不一定能成功 从/export/atlassian/application-data/jira/export下载至 ...
- mongodb认识
MongoDB安装 一.软件的获取地址 1.使用本安装文档提供的安装软件 本安装文档提供的软件为window版本的64位MongoDB的安装包 2.在官网上下载所需的安装软件 下载地址:https:/ ...
- 记录Js 文本框验证 与 IE兼容性
最近的日常就是将测试小姐姐提交的bug进行修改,想来这种事情还是比较好开展的,毕竟此项目已上线一年多,现在只是一些前端的问题需要改正.实际上手的时候并不是这样,原项目是在谷歌上运行,后来由于要新增一个 ...