import tensorflow as tf
from sklearn.datasets import load_digits
from sklearn.cross_validation import train_test_split
from sklearn.preprocessing import LabelBinarizer #load data
digits = load_digits()
X = digits.data
y = digits.target
y = LabelBinarizer().fit_transform(y)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=.3) #
# add layer
#
def add_layer(inputs, in_size, out_size, n_layer, activation_function = None):
layer_name = 'layer%s' % n_layer Weights = tf.Variable(tf.random_normal([in_size, out_size]), name='W') # hang lie
biases = tf.Variable(tf.zeros([1, out_size]) + 0.1, name = 'b') Wx_plus_b = tf.matmul(inputs, Weights) + biases
Wx_plus_b = tf.nn.dropout(Wx_plus_b, keep_prob) # if activation_function is None:
outputs = Wx_plus_b
else:
outputs = activation_function(Wx_plus_b) tf.summary.histogram(layer_name + '/outputs', outputs)
return outputs #
# define placeholder for inputs to network
#
keep_prob = tf.placeholder(tf.float32) #
xs = tf.placeholder(tf.float32, [None, 64]) # 8x8
ys = tf.placeholder(tf.float32, [None, 10]) #
# add output layer
#
l1 = add_layer(xs, 64, 50, 'l1', activation_function = tf.nn.tanh)
prediction = add_layer(l1, 50, 10, 'l2', activation_function = tf.nn.softmax) #
# the error between prediction and real data
#
cross_entropy = tf.reduce_mean(-tf.reduce_sum(ys * tf.log(prediction),
reduction_indices=[1])) #loss
tf.summary.scalar('loss', cross_entropy)
train_step = tf.train.GradientDescentOptimizer(0.6).minimize(cross_entropy) sess = tf.Session()
merged = tf.summary.merge_all() #summary writer goes here
train_writer = tf.summary.FileWriter("logs/train", sess.graph)
test_writer = tf.summary.FileWriter("logs/test", sess.graph) sess.run(tf.global_variables_initializer()) for i in range(500):
#sess.run(train_step, feed_dict={xs:X_train, ys:y_train, keep_prob:1.0}) # overfitted
sess.run(train_step, feed_dict={xs:X_train, ys:y_train, keep_prob:0.5}) # keep 0.5, drop 0.5
if i% 50 == 0:
#record loss
train_result = sess.run(merged, feed_dict={xs:X_train, ys:y_train, keep_prob:1})
test_result = sess.run(merged, feed_dict={xs:X_test, ys:y_test, keep_prob:1})
train_writer.add_summary(train_result, i)
test_writer.add_summary(test_result, i)

  

莫烦TensorFlow_10 过拟合的更多相关文章

  1. tensorflow 莫烦教程

    1,感谢莫烦 2,第一个实例:用tf拟合线性函数 import tensorflow as tf import numpy as np # create data x_data = np.random ...

  2. tensorflow学习笔记-bili莫烦

    bilibili莫烦tensorflow视频教程学习笔记 1.初次使用Tensorflow实现一元线性回归 # 屏蔽警告 import os os.environ[' import numpy as ...

  3. 【莫烦Pytorch】【P1】人工神经网络VS. 生物神经网络

    滴:转载引用请注明哦[握爪] https://www.cnblogs.com/zyrb/p/9700343.html 莫烦教程是一个免费的机器学习(不限于)的学习教程,幽默风俗的语言让我们这些刚刚起步 ...

  4. 稍稍乱入的CNN,本文依然是学习周莫烦视频的笔记。

    稍稍乱入的CNN,本文依然是学习周莫烦视频的笔记. 还有 google 在 udacity 上的 CNN 教程. CNN(Convolutional Neural Networks) 卷积神经网络简单 ...

  5. 莫烦大大TensorFlow学习笔记(9)----可视化

      一.Matplotlib[结果可视化] #import os #os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' import tensorflow as tf i ...

  6. scikit-learn学习笔记-bili莫烦

    bilibili莫烦scikit-learn视频学习笔记 1.使用KNN对iris数据分类 from sklearn import datasets from sklearn.model_select ...

  7. 莫烦pytorch学习笔记(八)——卷积神经网络(手写数字识别实现)

    莫烦视频网址 这个代码实现了预测和可视化 import os # third-party library import torch import torch.nn as nn import torch ...

  8. 莫烦pytorch学习笔记(七)——Optimizer优化器

    各种优化器的比较 莫烦的对各种优化通俗理解的视频 import torch import torch.utils.data as Data import torch.nn.functional as ...

  9. 莫烦PyTorch学习笔记(五)——模型的存取

    import torch from torch.autograd import Variable import matplotlib.pyplot as plt torch.manual_seed() ...

随机推荐

  1. php 学习笔记之搭建开发环境(mac版)

    Mac 系统默认集成了很多开发工具,其中就包括 php 所需要的一些软件工具. 下面我们将搭建最简单的 php 开发环境,每一步都会验证上一步的操作结构,请一步一步跟我一起搭建吧! web 服务器之 ...

  2. 【CF525E】Anya and Cubes(meet in middle)

    点此看题面 大致题意: 在\(n\)个数中选任意个数,并使其中至多\(k\)个数\(x_i\)变为\(x_i!\),求使这些数和为\(S\)的方案数. \(meet\ in\ middle\) 这应该 ...

  3. python统计wav文件的时长

    import wave import os.path # 音频存放文件夹绝对路径 filedir = '/Users/111/PycharmProjects/TextClassify/wav' lis ...

  4. Windows7运行python3,提示缺少api-ms-win-crt-runtime-l1-1.0.dll

    一.实验环境 1.Windows7x64_SP1 二.操作步骤 2.1 python官网下载python3.6后,安装.运行,提示如下错误: 2.2 解决方式 去微软官网下载安装:KB2999226补 ...

  5. Oracle索引知识学习笔记

    目录 一.Oracle索引简介 1.1 索引分类 1.2 索引数据结构 1.3 索引特性 1.4 索引使用注意要点 1.5.索引的缺点 1.6.索引失效 二.索引分类介绍 2.1.位图索引 1.2.函 ...

  6. 九、Spring之BeanFactory源码分析(一)

    Spring之BeanFactory源码分析(一) ​ 注意:该随笔内容完全引自https://blog.csdn.net/u014634338/article/details/82865644,写的 ...

  7. tensorflow: arg_scope()

    with arg_scope(): 1.允许我们设定一些共享参数,并将其进行保存,必要时还可以嵌套覆盖 2.在指定的函数调用时,可以将一些默认参数塞进去. 接下来看一个tensorflow自带的例子. ...

  8. SiftingAppender logback 动态 输出 日志 到指定日志文件

    SiftingAppender https://www.mkyong.com/logging/logback-different-log-file-for-each-thread/

  9. [Flutter] 转一个Flutter学习思维导图

    本文的思维导图均转自QQ群,感谢原作者(是谁?) 表单 按钮 视图 Sliver 路由 (Routes) 输入控件 对话框 MDC (Material Design Component) 状态管理 R ...

  10. yii2.0的学习之旅(二)

    前言:上一次我们简单认识了一下yii2.0安装,模型基本(增,删,改,查)操作 一.前后台数据交互 *如果你觉得默认的top样式太丑,可以这样关掉* *底部也可以这样关掉* (1)mvc合作操作数据 ...