TensorFlow keras卷积神经网络 添加L2正则化
model = keras.models.Sequential([
#卷积层1
keras.layers.Conv2D(32,kernel_size=5,strides=1,padding="same",data_format="channels_last",activation=tf.nn.relu,kernel_regularizer=keras.regularizers.l2(0.01)),
#池化层1
keras.layers.MaxPool2D(pool_size=2,strides=2,padding="same"),
#卷积层2
keras.layers.Conv2D(64,kernel_size=5,strides=1,padding="same",data_format="channels_last",activation=tf.nn.relu),
#池化层2
keras.layers.MaxPool2D(pool_size=2,strides=2,padding="same"),
#数据整理
keras.layers.Flatten(),
#1024个,全连接层
keras.layers.Dense(1024,activation=tf.nn.relu),
#100个,全连接层
keras.layers.Dense(100,activation=tf.nn.softmax)
])
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' from tensorflow.python.keras.datasets import cifar100
from tensorflow.python import keras
import tensorflow as tf class CNNMnist(object): model = keras.models.Sequential([
#卷积层1
keras.layers.Conv2D(32,kernel_size=5,strides=1,padding="same",data_format="channels_last",activation=tf.nn.relu,kernel_regularizer=keras.regularizers.l2(0.01)),
#池化层1
keras.layers.MaxPool2D(pool_size=2,strides=2,padding="same"),
#卷积层2
keras.layers.Conv2D(64,kernel_size=5,strides=1,padding="same",data_format="channels_last",activation=tf.nn.relu),
#池化层2
keras.layers.MaxPool2D(pool_size=2,strides=2,padding="same"),
#数据整理
keras.layers.Flatten(),
#1024个,全连接层
keras.layers.Dense(1024,activation=tf.nn.relu),
#100个,全连接层
keras.layers.Dense(100,activation=tf.nn.softmax)
]) def __init__(self):
(self.x_train,self.y_train),(self.x_test,self.y_test) = cifar100.load_data() self.x_train = self.x_train/255.0
self.x_test = self.x_test/255.0 def compile(self):
CNNMnist.model.compile(optimizer=keras.optimizers.Adam(),loss=keras.losses.sparse_categorical_crossentropy,metrics=["accuracy"]) def fit(self):
CNNMnist.model.fit(self.x_train,self.y_train,epochs=1,batch_size=32) def evaluate(self):
test_loss,test_acc = CNNMnist.model.evaluate(self.x_test,self.y_test)
print(test_loss,test_acc) if __name__ == '__main__':
cnn = CNNMnist()
print(CNNMnist.model.summary())
cnn.compile()
cnn.fit()
TensorFlow keras卷积神经网络 添加L2正则化的更多相关文章
- TensorFlow实现卷积神经网络
1 卷积神经网络简介 在介绍卷积神经网络(CNN)之前,我们需要了解全连接神经网络与卷积神经网络的区别,下面先看一下两者的结构,如下所示: 图1 全连接神经网络与卷积神经网络结构 虽然上图中显示的全连 ...
- 使用TensorFlow的卷积神经网络识别自己的单个手写数字,填坑总结
折腾了几天,爬了大大小小若干的坑,特记录如下.代码在最后面. 环境: Python3.6.4 + TensorFlow 1.5.1 + Win7 64位 + I5 3570 CPU 方法: 先用MNI ...
- tensorflow CNN 卷积神经网络中的卷积层和池化层的代码和效果图
tensorflow CNN 卷积神经网络中的卷积层和池化层的代码和效果图 因为很多 demo 都比较复杂,专门抽出这两个函数,写的 demo. 更多教程:http://www.tensorflown ...
- TensorFlow构建卷积神经网络/模型保存与加载/正则化
TensorFlow 官方文档:https://www.tensorflow.org/api_guides/python/math_ops # Arithmetic Operators import ...
- Python之TensorFlow的卷积神经网络-5
一.卷积神经网络(Convolutional Neural Networks, CNN)是一类包含卷积计算且具有深度结构的前馈神经网络(Feedforward Neural Networks),是深度 ...
- 【Python】keras卷积神经网络识别mnist
卷积神经网络的结构我随意设了一个. 结构大概是下面这个样子: 代码如下: import numpy as np from keras.preprocessing import image from k ...
- tensorflow 中的L1和L2正则化
import tensorflow as tf weights = tf.constant([[1.0, -2.0],[-3.0 , 4.0]]) >>> sess.run(tf.c ...
- 吴裕雄 python 神经网络——TensorFlow 使用卷积神经网络训练和预测MNIST手写数据集
import tensorflow as tf import numpy as np from tensorflow.examples.tutorials.mnist import input_dat ...
- Tensorflow之卷积神经网络(CNN)
前馈神经网络的弊端 前一篇文章介绍过MNIST,是采用的前馈神经网络的结构,这种结构有一个很大的弊端,就是提供的样本必须面面俱到,否则就容易出现预测失败.如下图: 同样是在一个图片中找圆形,如果左边为 ...
随机推荐
- Jmeter 中 Bean Shell 之全局变量
1.新建测试计划>线程组 > http 请求 -登录 获取token , 可以参照我以前写的这篇博客 https://www.cnblogs.com/cyit/p/12632445.htm ...
- Cygwin工具编译Ardupilot方法
注意:该编译方法生成的固件基于Chibios系统,如果想要Nuttx系统固件,需采用make编译,步骤见make编译说明部分. 软件安装准备 安装Cygwin 打开链接www.cygwin.com/i ...
- B 方块消消乐
时间限制 : - MS 空间限制 : - KB 评测说明 : 1s,128m 问题描述 何老板在玩一款消消乐游戏,游戏虽然简单,何老板仍旧乐此不疲.游戏一开始有n个边长为1的方块叠成一个高为n的 ...
- JDBC下Date类型转换问题
一.前言 在学过MVC后,其中的DAO层是负责与数据库进行进行数据交互,而service层个servlet层需要数据时,不允许直接向数据库要,而是通过Dao层来获取相关数据.这个时候,就引出一个规定& ...
- python中装饰器的使用
看个例子: # 定义装饰器函数 def log(func): """ 接受一个函数作为参数,并返回一个函数 :param func: :return: "&qu ...
- 在.NET Core中检查证书的到期日期
在 NUnit 测试中,我需要检查证书的有效期. 下面的代码片段可用于使用自定义证书验证回调检查任何证书属性. 所有你需要做的就是在回调中读取你感兴趣的属性,这样你就可以在之后检查它们. DateTi ...
- 【php】面向对象(二)
一. 封装: a) 描述:使用成员修饰符修饰成员属性和成员方法,能够最大限度的隐藏对象内部的细节,保证对象的安全 b) PPP修饰符:public(公共的),protected(受保护的),priva ...
- rest_framework-序列化-1
序列化 定义模型类 from django.db import models # Create your models here. class StuModel(models.Model): SEX_ ...
- Vue-cli2.0 第3节 解读Vue-cli模板
Vue-cli2.0 第3节 解读Vue-cli模板 目录 Vue-cli2.0 第3节 解读Vue-cli模板 第3节 解读Vue-cli模板 1. npm run build命令 2. main. ...
- MTK Android Camera新增差值
一. 计算需要的插值 如果原有的插值列表没有我们需要的插值的时候,要通过计算算出符合需求的插值,比如2700W的插值. 具体计算方法如下: 假设像素的长宽分别为X,Y,则插值为XY.由于MTK规定各参 ...