AlexNet实现cifar10数据集分类
import tensorflow as tf
import os
from matplotlib import pyplot as plt
import tensorflow.keras.datasets
from tensorflow.keras import Model
import numpy as np
from tensorflow.keras.layers import Dense,Flatten,BatchNormalization,Dropout,Conv2D,Activation,MaxPool2D
cifar10=tf.keras.datasets.cifar10
(x_train,y_train),(x_test,y_test)=cifar10.load_data()
x_train=x_train/255.
x_test=x_test/255. class AlexNet(Model):
def __init__(self):
super(AlexNet, self).__init__()
self.c1=Conv2D(filters=96,kernel_size=(3,3),strides=1,padding='valid')
self.b1=BatchNormalization()
self.a1=Activation('relu')
self.p1=MaxPool2D(pool_size=(3,3),strides=2) self.c2 = Conv2D(filters=384, kernel_size=(3, 3), strides=1, padding='same')
#self.b2 = BatchNormalization()
self.a2 = Activation('relu')
#self.p2 = MaxPool2D(pool_size=(3, 3), strides=2) self.c3 = Conv2D(filters=256, kernel_size=(3, 3), strides=1, padding='same')
# self.b2 = BatchNormalization()
self.a3 = Activation('relu')
self.p3 = MaxPool2D(pool_size=(3, 3), strides=2) self.flatten=Flatten()
self.f1 = Dense(2048,activation='relu')
self.d1=Dropout(0.5)
self.f2 = Dense(2048, activation='relu')
self.d2 = Dropout(0.5)
self.f3 = Dense(10, activation='softmax') def call(self,x): x = self.c1(x)
x = self.b1(x)
x = self.a1(x)
x = self.p1(x) x = self.c2(x)
x = self.a2(x) x = self.c3(x)
x = self.a3(x)
x = self.p3(x) x = self.flatten(x) x=self.f1(x)
x=self.d1(x)
x=self.f2(x)
x=self.d2(x)
y=self.f3(x)
return y model=AlexNet() model.compile(optimizer='adam',
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
metrics=['sparse_categorical_accuracy']) check_save_path='./checkpoint/AlexNet.ckpt'
if os.path.exists(check_save_path+'.index'):
print('-------------lodel the model------------')
model.load_weights(check_save_path) cp_callback=tf.keras.callbacks.ModelCheckpoint(filepath=check_save_path,save_best_only=True,
save_weights_only=True) history=model.fit(x_train,y_train,batch_size=128,epochs=5,validation_data=(x_test,y_test),
validation_freq=1,callbacks=[cp_callback]) model.summary() file=open('./AlexNet_wights.txt','w')
for v in model.trainable_variables:
file.write(str(v.name) + '\n')
file.write(str(v.shape) + '\n')
file.write(str(v.np()) + '\n')
file.close() ############可视化图像###############
acc=history.history['sparse_categorical_accuracy']
val_acc=history.history['sparse_categorical_val_accuracy']
loss=history.history['loss']
val_loss=history.history['val_loss'] plt.subplot(1,2,1)
plt.plot(acc)
plt.plot(val_acc)
plt.legend() plt.subplot(1,2,2)
plt.plot(loss)
plt.plot(val_loss)
plt.legend() plt.show()
此代码运行较慢,单次遍历需要近15分钟,由此可见两层全连接层2048个神经元远远拖慢运行速度
AlexNet实现cifar10数据集分类的更多相关文章
- 第十三节,使用带有全局平均池化层的CNN对CIFAR10数据集分类
这里使用的数据集仍然是CIFAR-10,由于之前写过一篇使用AlexNet对CIFAR数据集进行分类的文章,已经详细介绍了这个数据集,当时我们是直接把这些图片的数据文件下载下来,然后使用pickle进 ...
- 用pytorch进行CIFAR-10数据集分类
CIFAR-10.(Canadian Institute for Advanced Research)是由 Alex Krizhevsky.Vinod Nair 与 Geoffrey Hinton 收 ...
- python实现HOG+SVM对CIFAR-10数据集分类(上)
本博客只用于学习,如果有错误的地方,恳请指正,如需转载请注明出处. 看机器学习也是有一段时间了,这两天终于勇敢地踏出了第一步,实现了HOG+SVM对图片分类,具体代码可以在github上下载,http ...
- CIFAR-10数据集图像分类【PCA+基于最小错误率的贝叶斯决策】
CIFAR-10和CIFAR-100均是带有标签的数据集,都出自于规模更大的一个数据集,他有八千万张小图片.而本次实验采用CIFAR-10数据集,该数据集共有60000张彩色图像,这些图像是32*32 ...
- Ubuntu+caffe训练cifar-10数据集
1. 下载cifar-10数据库 ciffar-10数据集包含10种物体分类,50000张训练图片,10000张测试图片. 在终端执行指令下载cifar-10数据集(二进制文件): cd ~/caff ...
- caffe︱cifar-10数据集quick模型的官方案例
准备拿几个caffe官方案例用来练习,就看到了caffe中的官方案例有cifar-10数据集.于是练习了一下,在CPU情况下构建quick模型.主要参考博客:liumaolincycle的博客 配置: ...
- 单向LSTM笔记, LSTM做minist数据集分类
单向LSTM笔记, LSTM做minist数据集分类 先介绍下torch.nn.LSTM()这个API 1.input_size: 每一个时步(time_step)输入到lstm单元的维度.(实际输入 ...
- 机器学习与Tensorflow(3)—— 机器学习及MNIST数据集分类优化
一.二次代价函数 1. 形式: 其中,C为代价函数,X表示样本,Y表示实际值,a表示输出值,n为样本总数 2. 利用梯度下降法调整权值参数大小,推导过程如下图所示: 根据结果可得,权重w和偏置b的梯度 ...
- Python实现鸢尾花数据集分类问题——基于skearn的NaiveBayes
Python实现鸢尾花数据集分类问题——基于skearn的NaiveBayes 代码如下: # !/usr/bin/env python # encoding: utf-8 __author__ = ...
随机推荐
- MySQL选错索引导致的线上慢查询事故
前言 又和大家见面了!又两周过去了,我的云笔记里又多了几篇写了一半的文章草稿.有的是因为质量没有达到预期还准备再加点内容,有的则完全是一个灵感而已,内容完全木有.羡慕很多大佬们,一周能产出五六篇文章, ...
- lamp分离部署
目录 lamp分离部署 1. 安装httpd 2. 安装mysql 3. 安装php 4. 配置apache并部署项目 4.1 启用代理模块 4.2 配置虚拟主机 4.3 部署PbootCMSPHP企 ...
- Tarjan算法 学习笔记
前排提示:先学习拓扑排序,再学习Tarjan有奇效. -------------------------- Tarjan算法一般用于有向图里强连通分量的缩点. 强连通分量:有向图里能够互相到达的点的集 ...
- PEP8之常用编码规范-摘自官网
PEP8是广泛应用于Python编码中的规范,这里只会记录最重要的一部分:摘自官网 使用4个空格缩进,不要使用制表符. 4个空格是一个在小缩进(允许更大的嵌套深度)和大缩进(更容易阅读)的一种很好的折 ...
- centos,linux环境下安装JDK1.8完整
进入oracle官网下载安装包,cetos一般选择xx-xx-linux-x64.tar.gz.获取到地址后可以点击下载,也可以使用wget命令下载. 在得到下载好的文件后下面就可以开始安装了.比如我 ...
- 8月份Python招聘情况怎么样?Python爬取招聘数据,并进行分析
前言 拉勾招聘是专业的互联网求职招聘平台.致力于提供真实可靠的互联网招聘求职找工作信息.今天我们一起使用 python 采集拉钩的 python 招聘信息,分析一下找到高薪工作需要掌握哪些技术 开发环 ...
- Git本地仓库基本操作
目录 设置姓名和邮箱 创建仓库 提交本地代码 .gitignore git add git commit git status git diff 查看提交记录 撤销未提交的修改 版本回退 设置姓名和邮 ...
- 前端面试 vue 部分 (5)——VUE组件之间通信的方式有哪些
VUE组件之间通信的方式有哪些(SSS) 常见使用场景可以分为三类: 父子通信: null 父向子传递数据是通过 props ,子向父是通过 $emit / $on $emit / $bus Vuex ...
- JS 本地存储笔记
本地存储 1.数据存储在用户浏览器中的 2.设置.读取方便.甚至刷新都不会丢失数据 3.容量比较大,sessionStorange约5M,localstorage约20M ...
- Java对象与类—对象与类
1.类 类(class)是构造对象的模板,具体点说类就是对具有相同性质,相同行为的一群对象的抽象说明.由类构造(construst)对象的过程称为创建类的实例(instance). 2.对象 对象是类 ...