import tensorflow as tf
import os
from matplotlib import pyplot as plt
import tensorflow.keras.datasets
from tensorflow.keras import Model
import numpy as np
from tensorflow.keras.layers import Dense,Flatten,BatchNormalization,Dropout,Conv2D,Activation,MaxPool2D
cifar10=tf.keras.datasets.cifar10
(x_train,y_train),(x_test,y_test)=cifar10.load_data()
x_train=x_train/255.
x_test=x_test/255. class AlexNet(Model):
def __init__(self):
super(AlexNet, self).__init__()
self.c1=Conv2D(filters=96,kernel_size=(3,3),strides=1,padding='valid')
self.b1=BatchNormalization()
self.a1=Activation('relu')
self.p1=MaxPool2D(pool_size=(3,3),strides=2) self.c2 = Conv2D(filters=384, kernel_size=(3, 3), strides=1, padding='same')
#self.b2 = BatchNormalization()
self.a2 = Activation('relu')
#self.p2 = MaxPool2D(pool_size=(3, 3), strides=2) self.c3 = Conv2D(filters=256, kernel_size=(3, 3), strides=1, padding='same')
# self.b2 = BatchNormalization()
self.a3 = Activation('relu')
self.p3 = MaxPool2D(pool_size=(3, 3), strides=2) self.flatten=Flatten()
self.f1 = Dense(2048,activation='relu')
self.d1=Dropout(0.5)
self.f2 = Dense(2048, activation='relu')
self.d2 = Dropout(0.5)
self.f3 = Dense(10, activation='softmax') def call(self,x): x = self.c1(x)
x = self.b1(x)
x = self.a1(x)
x = self.p1(x) x = self.c2(x)
x = self.a2(x) x = self.c3(x)
x = self.a3(x)
x = self.p3(x) x = self.flatten(x) x=self.f1(x)
x=self.d1(x)
x=self.f2(x)
x=self.d2(x)
y=self.f3(x)
return y model=AlexNet() model.compile(optimizer='adam',
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
metrics=['sparse_categorical_accuracy']) check_save_path='./checkpoint/AlexNet.ckpt'
if os.path.exists(check_save_path+'.index'):
print('-------------lodel the model------------')
model.load_weights(check_save_path) cp_callback=tf.keras.callbacks.ModelCheckpoint(filepath=check_save_path,save_best_only=True,
save_weights_only=True) history=model.fit(x_train,y_train,batch_size=128,epochs=5,validation_data=(x_test,y_test),
validation_freq=1,callbacks=[cp_callback]) model.summary() file=open('./AlexNet_wights.txt','w')
for v in model.trainable_variables:
file.write(str(v.name) + '\n')
file.write(str(v.shape) + '\n')
file.write(str(v.np()) + '\n')
file.close() ############可视化图像###############
acc=history.history['sparse_categorical_accuracy']
val_acc=history.history['sparse_categorical_val_accuracy']
loss=history.history['loss']
val_loss=history.history['val_loss'] plt.subplot(1,2,1)
plt.plot(acc)
plt.plot(val_acc)
plt.legend() plt.subplot(1,2,2)
plt.plot(loss)
plt.plot(val_loss)
plt.legend() plt.show()

此代码运行较慢,单次遍历需要近15分钟,由此可见两层全连接层2048个神经元远远拖慢运行速度

AlexNet实现cifar10数据集分类的更多相关文章

  1. 第十三节,使用带有全局平均池化层的CNN对CIFAR10数据集分类

    这里使用的数据集仍然是CIFAR-10,由于之前写过一篇使用AlexNet对CIFAR数据集进行分类的文章,已经详细介绍了这个数据集,当时我们是直接把这些图片的数据文件下载下来,然后使用pickle进 ...

  2. 用pytorch进行CIFAR-10数据集分类

    CIFAR-10.(Canadian Institute for Advanced Research)是由 Alex Krizhevsky.Vinod Nair 与 Geoffrey Hinton 收 ...

  3. python实现HOG+SVM对CIFAR-10数据集分类(上)

    本博客只用于学习,如果有错误的地方,恳请指正,如需转载请注明出处. 看机器学习也是有一段时间了,这两天终于勇敢地踏出了第一步,实现了HOG+SVM对图片分类,具体代码可以在github上下载,http ...

  4. CIFAR-10数据集图像分类【PCA+基于最小错误率的贝叶斯决策】

    CIFAR-10和CIFAR-100均是带有标签的数据集,都出自于规模更大的一个数据集,他有八千万张小图片.而本次实验采用CIFAR-10数据集,该数据集共有60000张彩色图像,这些图像是32*32 ...

  5. Ubuntu+caffe训练cifar-10数据集

    1. 下载cifar-10数据库 ciffar-10数据集包含10种物体分类,50000张训练图片,10000张测试图片. 在终端执行指令下载cifar-10数据集(二进制文件): cd ~/caff ...

  6. caffe︱cifar-10数据集quick模型的官方案例

    准备拿几个caffe官方案例用来练习,就看到了caffe中的官方案例有cifar-10数据集.于是练习了一下,在CPU情况下构建quick模型.主要参考博客:liumaolincycle的博客 配置: ...

  7. 单向LSTM笔记, LSTM做minist数据集分类

    单向LSTM笔记, LSTM做minist数据集分类 先介绍下torch.nn.LSTM()这个API 1.input_size: 每一个时步(time_step)输入到lstm单元的维度.(实际输入 ...

  8. 机器学习与Tensorflow(3)—— 机器学习及MNIST数据集分类优化

    一.二次代价函数 1. 形式: 其中,C为代价函数,X表示样本,Y表示实际值,a表示输出值,n为样本总数 2. 利用梯度下降法调整权值参数大小,推导过程如下图所示: 根据结果可得,权重w和偏置b的梯度 ...

  9. Python实现鸢尾花数据集分类问题——基于skearn的NaiveBayes

    Python实现鸢尾花数据集分类问题——基于skearn的NaiveBayes 代码如下: # !/usr/bin/env python # encoding: utf-8 __author__ = ...

随机推荐

  1. 星屑幻想 optimal mark

    LINK :SP839 星屑幻想 取自 OJ 的名称 小事情...题目大意还是要说的这道题比较有意思,想了一段时间. 给你一张图 这张图给答案带来的贡献是每条边上两个点值得异或 一些点的值已经被确定 ...

  2. 用大白话的方式讲明白Java的StringBuilder、StringBuffer的扩容机制

    StringBuffer和StringBuilder,它们的底层char数组value默认的初始化容量是16,扩容只需要修改底层的char数组,两者的扩容最终都会调用到AbstractStringBu ...

  3. 【问题记录】springMVC @Valid使用不生效问题

    问题描述 在网上找到如何使用@Valid注解后,就把用到的配置和jar包加上,然后测试发现一直不生效.下面是配置及解决方法 配置 1.引入依赖 2.添加相应的配置(springmvc配置文件) < ...

  4. 搭建Redis主从复制的集群

    在主从复制模式的集群里,主节点一般是一个,从节点一般是两个或多个,写入主节点的数据会被复制到从节点上,这样一旦主节点出现故障,应用系统能切换到从节点去读写数据,这样能提升系统的可用性.而且如果再采用主 ...

  5. 能动手绝不多说:开源评论系统remark42上手指南

    能动手绝不多说:开源评论系统 remark42 上手指南 前言 写博客嘛, 谁不喜欢自己倒腾一下呢. 从自建系统到 Github Page, 从 Jekyll 到 Hexo, 年轻的时候谁不喜欢多折腾 ...

  6. GitLab Admin Area 500 Error

    GitLab Admin Area 500 Error GitLab Admin Area Settings 菜单全部报错 500 解决方法 执行: gitlab-rake cache:clear # ...

  7. [深度学习] Pytorch学习(一)—— torch tensor

    [深度学习] Pytorch学习(一)-- torch tensor 学习笔记 . 记录 分享 . 学习的代码环境:python3.6 torch1.3 vscode+jupyter扩展 #%% im ...

  8. C#LeetCode刷题之#724-寻找数组的中心索引( Find Pivot Index)

    问题 该文章的最新版本已迁移至个人博客[比特飞],单击链接 https://www.byteflying.com/archives/3742 访问. 给定一个整数类型的数组 nums,请编写一个能够返 ...

  9. 听说同学你搞不懂Java的LinkedHashMap,可笑

    先看再点赞,给自己一点思考的时间,微信搜索[沉默王二]关注这个有颜值却假装靠才华苟且的程序员.本文 GitHub github.com/itwanger 已收录,里面还有我精心为你准备的一线大厂面试题 ...

  10. IT技术人,“三十而已”

    最近电视剧<三十而已>热播,我家的电视机自然也是被霸屏,我还是跟着妹纸看了看,开头和结局完整看完,中间看了一点,大部分都是在微信公众号上通过别人的文章看完的.我个人也已经30+了,今天也和 ...