博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
『TensorFlow』正则化添加方法整理
阅读量:6312 次
发布时间:2019-06-22

本文共 3788 字,大约阅读时间需要 12 分钟。

一、基础正则化函数

tf.contrib.layers.l1_regularizer(scale, scope=None)

返回一个用来执行L1正则化的函数,函数的签名是func(weights)

参数:

  • scale: 正则项的系数.
  • scope: 可选的scope name

tf.contrib.layers.l2_regularizer(scale, scope=None)

先看看tf.contrib.layers.l2_regularizer(weight_decay)都执行了什么:

import tensorflow as tfsess=tf.Session()weight_decay=0.1tmp=tf.constant([0,1,2,3],dtype=tf.float32)"""l2_reg=tf.contrib.layers.l2_regularizer(weight_decay)a=tf.get_variable("I_am_a",regularizer=l2_reg,initializer=tmp) """#**上面代码的等价代码a=tf.get_variable("I_am_a",initializer=tmp)a2=tf.reduce_sum(a*a)*weight_decay/2;a3=tf.get_variable(a.name.split(":")[0]+"/Regularizer/l2_regularizer",initializer=a2)tf.add_to_collection(tf.GraphKeys.REGULARIZATION_LOSSES,a2)#**sess.run(tf.global_variables_initializer())keys = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)for key in keys:  print("%s : %s" %(key.name,sess.run(key)))
我们很容易可以模拟出tf.contrib.layers.l2_regularizer都做了什么,不过会让代码变丑。
以下比较完整实现L2 正则化。
import tensorflow as tfsess=tf.Session()weight_decay=0.1                                                #(1)定义weight_decayl2_reg=tf.contrib.layers.l2_regularizer(weight_decay)           #(2)定义l2_regularizer()tmp=tf.constant([0,1,2,3],dtype=tf.float32)a=tf.get_variable("I_am_a",regularizer=l2_reg,initializer=tmp)  #(3)创建variable,l2_regularizer复制给regularizer参数。                                                                #目测REXXX_LOSSES集合#regularizer定义会将a加入REGULARIZATION_LOSSES集合print("Global Set:")keys = tf.get_collection("variables")for key in keys:  print(key.name)print("Regular Set:")keys = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)for key in keys:  print(key.name)print("--------------------")sess.run(tf.global_variables_initializer())print(sess.run(a))reg_set=tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)   #(4)则REGULARIAZTION_LOSSES集合会包含所有被weight_decay后的参数和,将其相加l2_loss=tf.add_n(reg_set)print("loss=%s" %(sess.run(l2_loss)))"""此处输出0.7,即:   weight_decay*sigmal(w*2)/2=0.1*(0*0+1*1+2*2+3*3)/2=0.7其实代码自己写也很方便,用API看着比较正规。在网络模型中,直接将l2_loss加入loss就好了。(loss变大,执行train自然会decay)"""

二、添加正则化方法

a、原始办法

正则化常用到集合,下面是最原始的添加正则办法(直接在变量声明后将之添加进'losses'集合或tf.GraphKeys.LOESSES也行):

import tensorflow as tfimport numpy as npdef get_weights(shape, lambd):    var = tf.Variable(tf.random_normal(shape), dtype=tf.float32)    tf.add_to_collection('losses', tf.contrib.layers.l2_regularizer(lambd)(var))    return varx = tf.placeholder(tf.float32, shape=(None, 2))y_ = tf.placeholder(tf.float32, shape=(None, 1))batch_size = 8layer_dimension = [2, 10, 10, 10, 1]n_layers = len(layer_dimension)cur_lay = xin_dimension = layer_dimension[0]for i in range(1, n_layers):    out_dimension = layer_dimension[i]    weights = get_weights([in_dimension, out_dimension], 0.001)    bias = tf.Variable(tf.constant(0.1, shape=[out_dimension]))    cur_lay = tf.nn.relu(tf.matmul(cur_lay, weights)+bias)    in_dimension = layer_dimension[i]mess_loss = tf.reduce_mean(tf.square(y_-cur_lay))tf.add_to_collection('losses', mess_loss)loss = tf.add_n(tf.get_collection('losses'))

b、tf.contrib.layers.apply_regularization(regularizer, weights_list=None)

先看参数

  • regularizer:就是我们上一步创建的正则化方法
  • weights_list: 想要执行正则化方法的参数列表,如果为None的话,就取GraphKeys.WEIGHTS中的weights.

函数返回一个标量Tensor,同时,这个标量Tensor也会保存到GraphKeys.REGULARIZATION_LOSSES中.这个Tensor保存了计算正则项损失的方法.

tensorflow中的Tensor是保存了计算这个值的路径(方法),当我们run的时候,tensorflow后端就通过路径计算出Tensor对应的值

现在,我们只需将这个正则项损失加到我们的损失函数上就可以了.

如果是自己手动定义weight的话,需要手动将weight保存到GraphKeys.WEIGHTS中,但是如果使用layer的话,就不用这么麻烦了,别人已经帮你考虑好了.(最好自己验证一下tf.GraphKeys.WEIGHTS中是否包含了所有的weights,防止被坑)

c、使用slim

使用slim会简单很多:

with slim.arg_scope([slim.conv2d, slim.fully_connected],                            activation_fn=tf.nn.relu,                            weights_regularizer=slim.l2_regularizer(weight_decay)):    pass

此时添加集合为tf.GraphKeys.REGULARIZATION_LOSSES。

转载地址:http://rahxa.baihongyu.com/

你可能感兴趣的文章
rhel6下安装配置Squid过程
查看>>
《树莓派开发实战(第2版)》——1.1 选择树莓派型号
查看>>
在 Linux 下使用 fdisk 扩展分区容量
查看>>
结合AlphaGo算法和大数据的量化基本面分析法探讨
查看>>
如何在 Ubuntu Linux 16.04 LTS 中使用多个连接加速 apt-get/apt
查看>>
《OpenACC并行编程实战》—— 导读
查看>>
机器学习:用初等数学解读逻辑回归
查看>>
如何在 Ubuntu 中管理和使用逻辑卷管理 LVM
查看>>
Oracle原厂老兵:从负面案例看Hint的最佳使用方式
查看>>
把自己Github上的代码添加Cocoapods支持
查看>>
C语言OJ项目参考(2493)四则运算
查看>>
零基础入门深度学习(二):神经网络和反向传播算法
查看>>
find和xargs
查看>>
数据结构例程—— 交换排序之快速排序
查看>>
WKWebView代理方法解析
查看>>
IOS定位服务的应用
查看>>
[SMS&WAP]实例讲解制作OTA短信来自动配置手机WAP书签[附源码]
查看>>
IOS中图片(UIImage)拉伸技巧
查看>>
【工具】系统性能查看工具 dstat
查看>>
基于zepto或jquery的手机端弹出框成功,失败,加载特效
查看>>