import tensorflow as tf
from numpy.random import RandomState batch_size = 8
x = tf.placeholder(tf.float32, shape=(None, 2), name="x-input")
y_ = tf.placeholder(tf.float32, shape=(None, 1), name='y-input')
w1= tf.Variable(tf.random_normal([2, 1], stddev=1, seed=1))
y = tf.matmul(x, w1) # 定义损失函数使得预测少了的损失大,于是模型应该偏向多的方向预测。
loss_less = 10
loss_more = 1
loss = tf.reduce_sum(tf.where(tf.greater(y, y_), (y - y_) * loss_more, (y_ - y) * loss_less))
train_step = tf.train.AdamOptimizer(0.001).minimize(loss) rdm = RandomState(1)
X = rdm.rand(128,2)
Y = [[x1+x2+(rdm.rand()/10.0-0.05)] for (x1, x2) in X] with tf.Session() as sess:
init_op = tf.global_variables_initializer()
sess.run(init_op)
STEPS = 5000
for i in range(STEPS):
start = (i*batch_size) % 128
end = (i*batch_size) % 128 + batch_size
sess.run(train_step, feed_dict={x: X[start:end], y_: Y[start:end]})
if i % 1000 == 0:
print("After %d training step(s), w1 is: " % (i))
print sess.run(w1), "\n"
print "Final w1 is: \n", sess.run(w1)

loss_less = 1
loss_more = 10
loss = tf.reduce_sum(tf.where(tf.greater(y, y_), (y - y_) * loss_more, (y_ - y) * loss_less))
train_step = tf.train.AdamOptimizer(0.001).minimize(loss) with tf.Session() as sess:
init_op = tf.global_variables_initializer()
sess.run(init_op)
STEPS = 5000
for i in range(STEPS):
start = (i*batch_size) % 128
end = (i*batch_size) % 128 + batch_size
sess.run(train_step, feed_dict={x: X[start:end], y_: Y[start:end]})
if i % 1000 == 0:
print("After %d training step(s), w1 is: " % (i))
print sess.run(w1), "\n"
print "Final w1 is: \n", sess.run(w1)

loss = tf.losses.mean_squared_error(y, y_)
train_step = tf.train.AdamOptimizer(0.001).minimize(loss) with tf.Session() as sess:
init_op = tf.global_variables_initializer()
sess.run(init_op)
STEPS = 5000
for i in range(STEPS):
start = (i*batch_size) % 128
end = (i*batch_size) % 128 + batch_size
sess.run(train_step, feed_dict={x: X[start:end], y_: Y[start:end]})
if i % 1000 == 0:
print("After %d training step(s), w1 is: " % (i))
print sess.run(w1), "\n"
print "Final w1 is: \n", sess.run(w1)

吴裕雄 python 神经网络——TensorFlow 自定义损失函数的更多相关文章

  1. 吴裕雄 python 神经网络——TensorFlow 循环神经网络处理MNIST手写数字数据集

    #加载TF并导入数据集 import tensorflow as tf from tensorflow.contrib import rnn from tensorflow.examples.tuto ...

  2. 吴裕雄 python 神经网络——TensorFlow 使用卷积神经网络训练和预测MNIST手写数据集

    import tensorflow as tf import numpy as np from tensorflow.examples.tutorials.mnist import input_dat ...

  3. 吴裕雄 python 神经网络——TensorFlow 训练过程的可视化 TensorBoard的应用

    #训练过程的可视化 ,TensorBoard的应用 #导入模块并下载数据集 import tensorflow as tf from tensorflow.examples.tutorials.mni ...

  4. 吴裕雄 python 神经网络TensorFlow实现LeNet模型处理手写数字识别MNIST数据集

    import tensorflow as tf tf.reset_default_graph() # 配置神经网络的参数 INPUT_NODE = 784 OUTPUT_NODE = 10 IMAGE ...

  5. 吴裕雄 python 神经网络——TensorFlow 数据集高层操作

    import tempfile import tensorflow as tf train_files = tf.train.match_filenames_once("E:\\output ...

  6. 吴裕雄 python 神经网络——TensorFlow 输入数据处理框架

    import tensorflow as tf files = tf.train.match_filenames_once("E:\\MNIST_data\\output.tfrecords ...

  7. 吴裕雄 python 神经网络——TensorFlow 花瓣分类与迁移学习(2)

    import glob import os.path import numpy as np import tensorflow as tf from tensorflow.python.platfor ...

  8. 吴裕雄 python 神经网络——TensorFlow 花瓣识别2

    import glob import os.path import numpy as np import tensorflow as tf from tensorflow.python.platfor ...

  9. 吴裕雄 python 神经网络——TensorFlow训练神经网络:不使用滑动平均

    import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data INPUT_NODE = 784 ...

随机推荐

  1. 关于BaiduPSC-Go的一些bug的更正

    首先说下操作步骤 下载是在GutHub,这个不赘述,网上很多资料 下载之后配置环境变量,在path的后面加上一个分号,然后加上你下载的目录,目录名最好为英文 然后通过命令行CMD工具,输入BaiduP ...

  2. Attribute "resultType" must be declared for element type "update" or "insert"

    仔细查看错误如图所示: 解决错误就是把resultType去掉,因为在insert和update语句中是没有返回值的.小坑小坑 转自:https://blog.csdn.net/u013144287/ ...

  3. DFA 简易正则表达式匹配

    一个只能匹配非常简单的(字母 . + *)共 4 种状态的正则表达式语法的自动机(注意,仅限 DFA,没考虑 NFA): 好久之前写的了,记得有个 bug 一直没解决... #include < ...

  4. The entity type XXX is not part of the model for the current context.

    今天遇到了一个奇葩问题,虽然解决了,但还是一脸懵,先附赠一下别人的解决方案:https://www.cnblogs.com/zwjaaron/archive/2012/06/08/2541430.ht ...

  5. [lua]紫猫lua教程-命令宝典-L1-01-05. if判断结构

    L1[if]01. 简单的if判断结构 没什么说得 if得基本结构如下 xxx= ) then testlib.traceprint("1-100") ) then testlib ...

  6. CSS3的一个伪类选择器:nth-child()

    CSS3的一个伪类选择器“:nth-child()”. Table表格奇偶数行定义样式: 语法: :nth-child(an+b) 为什么选择她,因为我认为,这个选择器是最多学问的一个了.很可惜,据我 ...

  7. Servlt入门

    Servlt入门 java的两种体系结构 C/S (客户端/服务器)体系结构  通讯效率高且安全,但系统占用多 B/S (浏览器/服务器)体系结构    节约开发成本 C/S (客户端/服务器)体系结 ...

  8. C语言程序设计(二)

    目录:   1.算法基本概念 2.认识循环语句 3.算法的表示法 4.求素数 5.求闰年 6.判断一个数是否为回文数 算法基本概念: (一)一个程序主要包含的2方面信息: 1.对数据的描述,在程序中要 ...

  9. TM1638控制

    原理图图纸: 显示控制与读按键与寄存器的对应 驱动代码:码云:

  10. contextField 键盘只允许输入数字和小数点,并且现在小数点后位数

    - (BOOL)textField:(UITextField *)textField shouldChangeCharactersInRange:(NSRange)range replacementS ...