一、基础正则化函数

tf.contrib.layers.l1_regularizer(scale, scope=None)

返回一个用来执行L1正则化的函数,函数的签名是func(weights)
参数:

  • scale: 正则项的系数.
  • scope: 可选的scope name

tf.contrib.layers.l2_regularizer(scale, scope=None)

先看看tf.contrib.layers.l2_regularizer(weight_decay)都执行了什么:

import tensorflow as tf
sess=tf.Session()
weight_decay=0.1
tmp=tf.constant([0,1,2,3],dtype=tf.float32)
"""
l2_reg=tf.contrib.layers.l2_regularizer(weight_decay)
a=tf.get_variable("I_am_a",regularizer=l2_reg,initializer=tmp)
"""
#**上面代码的等价代码
a=tf.get_variable("I_am_a",initializer=tmp)
a2=tf.reduce_sum(a*a)*weight_decay/2;
a3=tf.get_variable(a.name.split(":")[0]+"/Regularizer/l2_regularizer",initializer=a2)
tf.add_to_collection(tf.GraphKeys.REGULARIZATION_LOSSES,a2)
#**
sess.run(tf.global_variables_initializer())
keys = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
for key in keys:
print("%s : %s" %(key.name,sess.run(key)))
我们很容易可以模拟出tf.contrib.layers.l2_regularizer都做了什么,不过会让代码变丑。
以下比较完整实现L2 正则化。
import tensorflow as tf
sess=tf.Session()
weight_decay=0.1 #(1)定义weight_decay
l2_reg=tf.contrib.layers.l2_regularizer(weight_decay) #(2)定义l2_regularizer()
tmp=tf.constant([0,1,2,3],dtype=tf.float32)
a=tf.get_variable("I_am_a",regularizer=l2_reg,initializer=tmp) #(3)创建variable,l2_regularizer复制给regularizer参数。
#目测REXXX_LOSSES集合
#regularizer定义会将a加入REGULARIZATION_LOSSES集合
print("Global Set:")
keys = tf.get_collection("variables")
for key in keys:
print(key.name)
print("Regular Set:")
keys = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
for key in keys:
print(key.name)
print("--------------------")
sess.run(tf.global_variables_initializer())
print(sess.run(a))
reg_set=tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES) #(4)则REGULARIAZTION_LOSSES集合会包含所有被weight_decay后的参数和,将其相加
l2_loss=tf.add_n(reg_set)
print("loss=%s" %(sess.run(l2_loss)))
"""
此处输出0.7,即:
weight_decay*sigmal(w*2)/2=0.1*(0*0+1*1+2*2+3*3)/2=0.7
其实代码自己写也很方便,用API看着比较正规。
在网络模型中,直接将l2_loss加入loss就好了。(loss变大,执行train自然会decay)
"""

二、添加正则化方法

a、原始办法

正则化常用到集合,下面是最原始的添加正则办法(直接在变量声明后将之添加进'losses'集合或tf.GraphKeys.LOESSES也行):

import tensorflow as tf
import numpy as np def get_weights(shape, lambd): var = tf.Variable(tf.random_normal(shape), dtype=tf.float32)
tf.add_to_collection('losses', tf.contrib.layers.l2_regularizer(lambd)(var))
return var x = tf.placeholder(tf.float32, shape=(None, 2))
y_ = tf.placeholder(tf.float32, shape=(None, 1))
batch_size = 8
layer_dimension = [2, 10, 10, 10, 1]
n_layers = len(layer_dimension)
cur_lay = x
in_dimension = layer_dimension[0] for i in range(1, n_layers):
out_dimension = layer_dimension[i]
weights = get_weights([in_dimension, out_dimension], 0.001)
bias = tf.Variable(tf.constant(0.1, shape=[out_dimension]))
cur_lay = tf.nn.relu(tf.matmul(cur_lay, weights)+bias)
in_dimension = layer_dimension[i] mess_loss = tf.reduce_mean(tf.square(y_-cur_lay))
tf.add_to_collection('losses', mess_loss)
loss = tf.add_n(tf.get_collection('losses'))

b、tf.contrib.layers.apply_regularization(regularizer, weights_list=None)

先看参数

  • regularizer:就是我们上一步创建的正则化方法
  • weights_list: 想要执行正则化方法的参数列表,如果为None的话,就取GraphKeys.WEIGHTS中的weights.

函数返回一个标量Tensor,同时,这个标量Tensor也会保存到GraphKeys.REGULARIZATION_LOSSES中.这个Tensor保存了计算正则项损失的方法.

tensorflow中的Tensor是保存了计算这个值的路径(方法),当我们run的时候,tensorflow后端就通过路径计算出Tensor对应的值

现在,我们只需将这个正则项损失加到我们的损失函数上就可以了.

如果是自己手动定义weight的话,需要手动将weight保存到GraphKeys.WEIGHTS中,但是如果使用layer的话,就不用这么麻烦了,别人已经帮你考虑好了.(最好自己验证一下tf.GraphKeys.WEIGHTS中是否包含了所有的weights,防止被坑)

c、使用slim

使用slim会简单很多:

 with slim.arg_scope([slim.conv2d, slim.fully_connected],
activation_fn=tf.nn.relu,
weights_regularizer=slim.l2_regularizer(weight_decay)):
pass

此时添加集合为tf.GraphKeys.REGULARIZATION_LOSSES。

『TensorFlow』正则化添加方法整理的更多相关文章

  1. 『TensorFlow』专题汇总

    TensorFlow:官方文档 TensorFlow:项目地址 本篇列出文章对于全零新手不太合适,可以尝试TensorFlow入门系列博客,搭配其他资料进行学习. Keras使用tf.Session训 ...

  2. 『TensorFlow』滑动平均

    滑动平均会为目标变量维护一个影子变量,影子变量不影响原变量的更新维护,但是在测试或者实际预测过程中(非训练时),使用影子变量代替原变量. 1.滑动平均求解对象初始化 ema = tf.train.Ex ...

  3. 『TensorFlow』模型保存和载入方法汇总

    『TensorFlow』第七弹_保存&载入会话_霸王回马 一.TensorFlow常规模型加载方法 保存模型 tf.train.Saver()类,.save(sess, ckpt文件目录)方法 ...

  4. 『TensorFlow』读书笔记_进阶卷积神经网络_分类cifar10_上

    完整项目见:Github 完整项目中最终使用了ResNet进行分类,而卷积版本较本篇中结构为了提升训练效果也略有改动 本节主要介绍进阶的卷积神经网络设计相关,数据读入以及增强在下一节再与介绍 网络相关 ...

  5. 『TensorFlow』TFR数据预处理探究以及框架搭建

    一.TFRecord文件书写效率对比(单线程和多线程对比) 1.准备工作 # Author : Hellcat # Time : 18-1-15 ''' import os os.environ[&q ...

  6. 『TensorFlow』SSD源码学习_其一:论文及开源项目文档介绍

    一.论文介绍 读论文系列:Object Detection ECCV2016 SSD 一句话概括:SSD就是关于类别的多尺度RPN网络 基本思路: 基础网络后接多层feature map 多层feat ...

  7. 『TensorFlow』流程控制

    『PyTorch』第六弹_最小二乘法对比PyTorch和TensorFlow TensorFlow 控制流程操作 TensorFlow 提供了几个操作和类,您可以使用它们来控制操作的执行并向图中添加条 ...

  8. 『TensorFlow』读书笔记_降噪自编码器

    『TensorFlow』降噪自编码器设计  之前学习过的代码,又敲了一遍,新的收获也还是有的,因为这次注释写的比较详尽,所以再次记录一下,具体的相关知识查阅之前写的文章即可(见上面链接). # Aut ...

  9. 『TensorFlow』梯度优化相关

    tf.trainable_variables可以得到整个模型中所有trainable=True的Variable,也是自由处理梯度的基础 基础梯度操作方法: tf.gradients 用来计算导数.该 ...

随机推荐

  1. :after和:before 伪类

    1 使用伪类画三角形 .div{ margin-top: 100px; margin-left: 100px; } .div:after{ content: ''; display:inline-bl ...

  2. Django-- KindEditor 富文本编辑器使用

    KindEditor是一款还不错的开源的HTML可视化编辑器,主要用于让用户在网站上获得所见即所得编辑效果,兼容IE.Firefox.Chrome.Safari.Opera等主流浏览器.之所以推荐这一 ...

  3. 微信小游戏跳一跳简单手动外挂(基于adb 和 python)

    只有两个python文件,代码很简单. shell.py: #coding:utf-8 import subprocess import math import os def execute_comm ...

  4. Python递归函数介绍

    一.递归的定义 1.什么是递归:在一个函数里在调用这个函数本身 2.最大递归层数做了一个限制:997,但是也可以自己限制 # 验证 997 def foo(n): print(n) n+=1 foo( ...

  5. Redis入门到高可用(十六)—— 持久化

    一.持久化概念 二.持久化方式 三.redis持久化方式之——RDB 1.什么是RDB 在 Redis 运行时, RDB 程序将当前内存中的数据库快照保存到磁盘文件中, 在 Redis 重启动时, R ...

  6. swagger:API在线文档自动生成框架

    传统的API从开发测试开始我们经常借用类似Postman.fiddle等等去做接口测试等等工具:Swagger 为API的在线测试.在线文档提供了一个新的简便的解决方案: NET 使用Swagger ...

  7. Dubbo分布式服务框架

    Dubbo (开源分布式服务框架) 编辑 本词条缺少信息栏,补充相关内容使词条更完整,还能快速升级,赶紧来编辑吧! Dubbo是 [1]  阿里巴巴公司开源的一个高性能优秀的服务框架,使得应用可通过高 ...

  8. CentOS 7 DR模式LVS搭建

    调度器LB : 192.168.94.11 真实web服务器1 : 192.168.94.22 真实web服务器2 : 192.168.94.33 VIP : 192.168.94.111 脚本如下 ...

  9. swapper_pg_dir主内核页表、init和kthreadd、do_fork时新建子进程页表、vmalloc与kmalloc

    都是以前看到一个点扯出的很多东西,当时做的总结,有问题欢迎讨论,现在来源难寻,侵删! 1.Init_task.idle.init和kthreadd的区别和联系 idle进程其pid=0,其前身是系统创 ...

  10. 《CSS世界》读书笔记(一)

    <!-- <CSS世界> 张鑫旭 著 --> CSS世界构建的基石是HTML,而HTML最具代表的两个基石<div>和<span>正好是CSS世界中块级 ...