import tensorflow as tf
from sklearn.datasets import load_digits
from sklearn.cross_validation import train_test_split
from sklearn.preprocessing import LabelBinarizer #load data
digits = load_digits()
X = digits.data
y = digits.target
y = LabelBinarizer().fit_transform(y)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=.3) #
# add layer
#
def add_layer(inputs, in_size, out_size, n_layer, activation_function = None):
layer_name = 'layer%s' % n_layer Weights = tf.Variable(tf.random_normal([in_size, out_size]), name='W') # hang lie
biases = tf.Variable(tf.zeros([1, out_size]) + 0.1, name = 'b') Wx_plus_b = tf.matmul(inputs, Weights) + biases
Wx_plus_b = tf.nn.dropout(Wx_plus_b, keep_prob) # if activation_function is None:
outputs = Wx_plus_b
else:
outputs = activation_function(Wx_plus_b) tf.summary.histogram(layer_name + '/outputs', outputs)
return outputs #
# define placeholder for inputs to network
#
keep_prob = tf.placeholder(tf.float32) #
xs = tf.placeholder(tf.float32, [None, 64]) # 8x8
ys = tf.placeholder(tf.float32, [None, 10]) #
# add output layer
#
l1 = add_layer(xs, 64, 50, 'l1', activation_function = tf.nn.tanh)
prediction = add_layer(l1, 50, 10, 'l2', activation_function = tf.nn.softmax) #
# the error between prediction and real data
#
cross_entropy = tf.reduce_mean(-tf.reduce_sum(ys * tf.log(prediction),
reduction_indices=[1])) #loss
tf.summary.scalar('loss', cross_entropy)
train_step = tf.train.GradientDescentOptimizer(0.6).minimize(cross_entropy) sess = tf.Session()
merged = tf.summary.merge_all() #summary writer goes here
train_writer = tf.summary.FileWriter("logs/train", sess.graph)
test_writer = tf.summary.FileWriter("logs/test", sess.graph) sess.run(tf.global_variables_initializer()) for i in range(500):
#sess.run(train_step, feed_dict={xs:X_train, ys:y_train, keep_prob:1.0}) # overfitted
sess.run(train_step, feed_dict={xs:X_train, ys:y_train, keep_prob:0.5}) # keep 0.5, drop 0.5
if i% 50 == 0:
#record loss
train_result = sess.run(merged, feed_dict={xs:X_train, ys:y_train, keep_prob:1})
test_result = sess.run(merged, feed_dict={xs:X_test, ys:y_test, keep_prob:1})
train_writer.add_summary(train_result, i)
test_writer.add_summary(test_result, i)

  

莫烦TensorFlow_10 过拟合的更多相关文章

  1. tensorflow 莫烦教程

    1,感谢莫烦 2,第一个实例:用tf拟合线性函数 import tensorflow as tf import numpy as np # create data x_data = np.random ...

  2. tensorflow学习笔记-bili莫烦

    bilibili莫烦tensorflow视频教程学习笔记 1.初次使用Tensorflow实现一元线性回归 # 屏蔽警告 import os os.environ[' import numpy as ...

  3. 【莫烦Pytorch】【P1】人工神经网络VS. 生物神经网络

    滴:转载引用请注明哦[握爪] https://www.cnblogs.com/zyrb/p/9700343.html 莫烦教程是一个免费的机器学习(不限于)的学习教程,幽默风俗的语言让我们这些刚刚起步 ...

  4. 稍稍乱入的CNN,本文依然是学习周莫烦视频的笔记。

    稍稍乱入的CNN,本文依然是学习周莫烦视频的笔记. 还有 google 在 udacity 上的 CNN 教程. CNN(Convolutional Neural Networks) 卷积神经网络简单 ...

  5. 莫烦大大TensorFlow学习笔记(9)----可视化

      一.Matplotlib[结果可视化] #import os #os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' import tensorflow as tf i ...

  6. scikit-learn学习笔记-bili莫烦

    bilibili莫烦scikit-learn视频学习笔记 1.使用KNN对iris数据分类 from sklearn import datasets from sklearn.model_select ...

  7. 莫烦pytorch学习笔记(八)——卷积神经网络(手写数字识别实现)

    莫烦视频网址 这个代码实现了预测和可视化 import os # third-party library import torch import torch.nn as nn import torch ...

  8. 莫烦pytorch学习笔记(七)——Optimizer优化器

    各种优化器的比较 莫烦的对各种优化通俗理解的视频 import torch import torch.utils.data as Data import torch.nn.functional as ...

  9. 莫烦PyTorch学习笔记(五)——模型的存取

    import torch from torch.autograd import Variable import matplotlib.pyplot as plt torch.manual_seed() ...

随机推荐

  1. NOIP 2011 计算系数

    洛谷 P1313 计算系数 洛谷传送门 JDOJ 1747: [NOIP2011]计算系数 D2 T1 JDOJ传送门 Description 给定一个多项式(ax + by)k,请求出多项式展开后x ...

  2. Python入门基础学习记录(二)汇率案例学习记录

    一.汇总整理 1.操作 ①新建python文件 工程右键--new--python file 2.注意问题与知识点 >变量定义:直接写变量名即可,例如定义一个字符串并赋值123: rmb_str ...

  3. springboot2.0 management.security.enabled无效

    在1.5.x版本中通过management.security.enabled=false来暴露所有端点 在使用springcloud的时候,如果基于springboot2的版本的配置中心,无法使用SV ...

  4. Chrome保存整个网页为图片

    打开需要保存为图片的网页 然后按F12,接着按Ctrl+Shift+P 在红框内输入full 点击下面的“Capture full size screenshot”就可以保存整个网页为图片了 原文出处 ...

  5. 自动化API之一 自动生成Mysql数据库的微服务API

        本文演示如何利用Uniconnector平台,自动生成Mysql数据库的API,节约开发人员编写后台API的时间.使用生成API的前提是开发者有 自己的数据库,有数据库的管理权限,并能通过外网 ...

  6. 尽解powershell的workflow

    尽解powershell的workflow -------1[简介]--------- Microsoft .NET Framework 4.0 发布于2010年4月左右..net4 的新特性,是并行 ...

  7. jQuery 源码解析(八) 异步队列模块 Callbacks 回调函数详解

    异步队列用于实现异步任务和回调函数的解耦,为ajax模块.队列模块.ready事件提供基础功能,包含三个部分:Query.Callbacks(flags).jQuery.Deferred(funct) ...

  8. 【01】Nginx:编译安装/动态添加模块

    写在前面的话 说起 Nginx,别说运维,就是很多开发人员也很熟悉,毕竟如今已经 2019 年了,Apache 更多的要么成为了历史,要么成为了历史残留. 我们在提及 Nginx 的时候,一直在强调他 ...

  9. .net基础加强

    1.冒泡排序 请通过冒泡排序法对整数数组{ 1, 3, 5, 7, 90, 2, 4, 6, 8, 10 }实现升序排序 , , , , , , , , , }; BubbleSort(num); C ...

  10. Python基础22

    数据类型可变不可变,说的是“指向”. 深浅拷贝.