AlexNet实现cifar10数据集分类
import tensorflow as tf
import os
from matplotlib import pyplot as plt
import tensorflow.keras.datasets
from tensorflow.keras import Model
import numpy as np
from tensorflow.keras.layers import Dense,Flatten,BatchNormalization,Dropout,Conv2D,Activation,MaxPool2D
cifar10=tf.keras.datasets.cifar10
(x_train,y_train),(x_test,y_test)=cifar10.load_data()
x_train=x_train/255.
x_test=x_test/255. class AlexNet(Model):
def __init__(self):
super(AlexNet, self).__init__()
self.c1=Conv2D(filters=96,kernel_size=(3,3),strides=1,padding='valid')
self.b1=BatchNormalization()
self.a1=Activation('relu')
self.p1=MaxPool2D(pool_size=(3,3),strides=2) self.c2 = Conv2D(filters=384, kernel_size=(3, 3), strides=1, padding='same')
#self.b2 = BatchNormalization()
self.a2 = Activation('relu')
#self.p2 = MaxPool2D(pool_size=(3, 3), strides=2) self.c3 = Conv2D(filters=256, kernel_size=(3, 3), strides=1, padding='same')
# self.b2 = BatchNormalization()
self.a3 = Activation('relu')
self.p3 = MaxPool2D(pool_size=(3, 3), strides=2) self.flatten=Flatten()
self.f1 = Dense(2048,activation='relu')
self.d1=Dropout(0.5)
self.f2 = Dense(2048, activation='relu')
self.d2 = Dropout(0.5)
self.f3 = Dense(10, activation='softmax') def call(self,x): x = self.c1(x)
x = self.b1(x)
x = self.a1(x)
x = self.p1(x) x = self.c2(x)
x = self.a2(x) x = self.c3(x)
x = self.a3(x)
x = self.p3(x) x = self.flatten(x) x=self.f1(x)
x=self.d1(x)
x=self.f2(x)
x=self.d2(x)
y=self.f3(x)
return y model=AlexNet() model.compile(optimizer='adam',
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
metrics=['sparse_categorical_accuracy']) check_save_path='./checkpoint/AlexNet.ckpt'
if os.path.exists(check_save_path+'.index'):
print('-------------lodel the model------------')
model.load_weights(check_save_path) cp_callback=tf.keras.callbacks.ModelCheckpoint(filepath=check_save_path,save_best_only=True,
save_weights_only=True) history=model.fit(x_train,y_train,batch_size=128,epochs=5,validation_data=(x_test,y_test),
validation_freq=1,callbacks=[cp_callback]) model.summary() file=open('./AlexNet_wights.txt','w')
for v in model.trainable_variables:
file.write(str(v.name) + '\n')
file.write(str(v.shape) + '\n')
file.write(str(v.np()) + '\n')
file.close() ############可视化图像###############
acc=history.history['sparse_categorical_accuracy']
val_acc=history.history['sparse_categorical_val_accuracy']
loss=history.history['loss']
val_loss=history.history['val_loss'] plt.subplot(1,2,1)
plt.plot(acc)
plt.plot(val_acc)
plt.legend() plt.subplot(1,2,2)
plt.plot(loss)
plt.plot(val_loss)
plt.legend() plt.show()
此代码运行较慢,单次遍历需要近15分钟,由此可见两层全连接层2048个神经元远远拖慢运行速度
AlexNet实现cifar10数据集分类的更多相关文章
- 第十三节,使用带有全局平均池化层的CNN对CIFAR10数据集分类
这里使用的数据集仍然是CIFAR-10,由于之前写过一篇使用AlexNet对CIFAR数据集进行分类的文章,已经详细介绍了这个数据集,当时我们是直接把这些图片的数据文件下载下来,然后使用pickle进 ...
- 用pytorch进行CIFAR-10数据集分类
CIFAR-10.(Canadian Institute for Advanced Research)是由 Alex Krizhevsky.Vinod Nair 与 Geoffrey Hinton 收 ...
- python实现HOG+SVM对CIFAR-10数据集分类(上)
本博客只用于学习,如果有错误的地方,恳请指正,如需转载请注明出处. 看机器学习也是有一段时间了,这两天终于勇敢地踏出了第一步,实现了HOG+SVM对图片分类,具体代码可以在github上下载,http ...
- CIFAR-10数据集图像分类【PCA+基于最小错误率的贝叶斯决策】
CIFAR-10和CIFAR-100均是带有标签的数据集,都出自于规模更大的一个数据集,他有八千万张小图片.而本次实验采用CIFAR-10数据集,该数据集共有60000张彩色图像,这些图像是32*32 ...
- Ubuntu+caffe训练cifar-10数据集
1. 下载cifar-10数据库 ciffar-10数据集包含10种物体分类,50000张训练图片,10000张测试图片. 在终端执行指令下载cifar-10数据集(二进制文件): cd ~/caff ...
- caffe︱cifar-10数据集quick模型的官方案例
准备拿几个caffe官方案例用来练习,就看到了caffe中的官方案例有cifar-10数据集.于是练习了一下,在CPU情况下构建quick模型.主要参考博客:liumaolincycle的博客 配置: ...
- 单向LSTM笔记, LSTM做minist数据集分类
单向LSTM笔记, LSTM做minist数据集分类 先介绍下torch.nn.LSTM()这个API 1.input_size: 每一个时步(time_step)输入到lstm单元的维度.(实际输入 ...
- 机器学习与Tensorflow(3)—— 机器学习及MNIST数据集分类优化
一.二次代价函数 1. 形式: 其中,C为代价函数,X表示样本,Y表示实际值,a表示输出值,n为样本总数 2. 利用梯度下降法调整权值参数大小,推导过程如下图所示: 根据结果可得,权重w和偏置b的梯度 ...
- Python实现鸢尾花数据集分类问题——基于skearn的NaiveBayes
Python实现鸢尾花数据集分类问题——基于skearn的NaiveBayes 代码如下: # !/usr/bin/env python # encoding: utf-8 __author__ = ...
随机推荐
- 浅谈二分图的最大匹配和二分图的KM算法
二分图还可以,但是我不太精通.我感觉这是一个很烦的问题但是学网络流不得不学它.硬啃吧. 人比较蠢,所以思考几天才有如下理解.希望能说服我或者说服你. 二分图的判定不再赘述一个图是可被划分成一个二分图当 ...
- IntelliJ IDEA 控制台输出中文乱码
IntelliJ IDEA 控制台输出中文乱码部分如图所示: 解决方法一: 1.打开IntelliJ IDEA本地安装目录中bin文件夹下的idea.exe.vmoptions和idea64.exe. ...
- Python网络数据采集PDF高清完整版免费下载|百度云盘|Python基础教程免费电子书
点击获取提取码:jrno 内容提要 本书采用简洁强大的 Python 语言,介绍了网络数据采集,并为采集新式网络中的各种数据类 型提供了全面的指导.第一部分重点介绍网络数据采集的基本原理:如何用 Py ...
- Netty(一):server启动流程解析
netty作为一个被广泛应用的通信框架,有必要我们多了解一点. 实际上netty的几个重要的技术亮点: 1. reactor的线程模型; 2. 安全有效的nio非阻塞io模型应用; 3. pipeli ...
- python range函数的用法
range 函数是Python内置函数.可创建一个整数列表,一般用在 for 循环中. 函数语法:range(start, stop[, step]) start: 计数从 start 开始.默认是从 ...
- SPM:Single-stage Multi-person Pose Machines
figure1图b figure1 -a figure3-a 图一-a
- C#LeetCode刷题之#21-合并两个有序链表(Merge Two Sorted Lists)
问题 该文章的最新版本已迁移至个人博客[比特飞],单击链接 https://www.byteflying.com/archives/3818 访问. 将两个有序链表合并为一个新的有序链表并返回.新链表 ...
- 虚拟化技术之kvm基础
一.KVM简介 KVM的全称是kernel base virtual machine(基于内核的虚拟机)是一个开源的系统虚拟化模块,自Linux 2.6.20之后集成在Linux的各个主要发行版本中. ...
- 前端路由、后端路由——想要学好vue-router 或者 node.js 必须得明白的两个概念
前端路由和后端路由的概念讲解 引言 正文 一.路由的概念 二.后端路由 三.前端路由 四.其他知识 结束语 引言 无论你是正在学习vue 还是在学习node, 你一定会碰到前端路由和后端路由这两个概念 ...
- JVM简记
1.JVM概述 JVM(Java virtual Machine)指以软件的方式模拟具有完整硬件系统功能.运行在一个完全隔离环境中的完整计算机系统 ,是物理机的软件实现. JVM是一种规范,实现产品常 ...