# -*- coding=utf-8 -*-
import numpy as np
import keras
from keras.models import Sequential
from keras.layers import Dense,Flatten,Dropout
from keras.optimizers import Adadelta
from keras.datasets import cifar10
from keras import applications

import matplotlib.pyplot as plt
%matplotlib inline

vgg_model=applications.VGG19(include_top=False,weights='imagenet')
vgg_model.summary()

(train_x,train_y),(test_x,test_y)=cifar10.load_data()
print(train_x.shape,train_y.shape,test_x.shape,test_y.shape)

n_classes=10
train_y=keras.utils.to_categorical(train_y,n_classes)
test_y=keras.utils.to_categorical(test_y,n_classes)

bottleneck_feature_train=vgg_model.predict(train_x,verbose=1)
bottleneck_feature_test=vgg_model.predict(test_x,verbose=1)

print(bottleneck_feature_train.shape,bottleneck_feature_test.shape)

my_model=Sequential()
my_model.add(Flatten())###my_model.add(Flatten(input_shape=?))
my_model.add(Dense(512,activation='relu'))
my_model.add(Dropout(0.5))
my_model.add(Dense(256,activation='relu'))
my_model.add(Dropout(0.5))
my_model.add(Dense(n_classes,activation='softmax'))
my_model.compile(optimizer=Adadelta(),loss="categorical_crossentropy",\
metrics=['accuracy'])
my_model.fit(bottleneck_feature_train,train_y,batch_size=128,epochs=50,verbose=1)

evaluation=my_model.evaluate(bottleneck_feature_test,test_y,batch_size=128,verbose=0)
print("loss:",evaluation[0],"accuracy:",evaluation[1])

def predict_label(img_idx,show_proba=True):
plt.imshow(train_x[img_idx],aspect='auto')
plt.title("Image to be labeled")
plt.show()
img_4D=(bottleneck_feature_train[img_idx])[np.newaxis,:,:,:]
prediction=my_model.predict_classes(img_4D,batch_size=1,verbose=0)
print("Actual class:{0}\nPredict class:{1}".format(np.argmax(train_y[img_idx],0),prediction))

if show_proba:
pred=my_model.predict_proba(img_4D,batch_size=1,verbose=0)
print(pred)

for i in range(3):
predict_label(i,show_proba=True)

吴裕雄 python神经网络(8)的更多相关文章

  1. 吴裕雄 python神经网络 花朵图片识别(10)

    import osimport numpy as npimport matplotlib.pyplot as pltfrom PIL import Image, ImageChopsfrom skim ...

  2. 吴裕雄 python神经网络 花朵图片识别(9)

    import osimport numpy as npimport matplotlib.pyplot as pltfrom PIL import Image, ImageChopsfrom skim ...

  3. 吴裕雄 python神经网络 手写数字图片识别(5)

    import kerasimport matplotlib.pyplot as pltfrom keras.models import Sequentialfrom keras.layers impo ...

  4. 吴裕雄 python神经网络 水果图片识别(4)

    # coding: utf-8 # In[1]:import osimport numpy as npfrom skimage import color, data, transform, io # ...

  5. 吴裕雄 python神经网络 水果图片识别(3)

    import osimport kerasimport timeimport numpy as npimport tensorflow as tffrom random import shufflef ...

  6. 吴裕雄 python神经网络 水果图片识别(2)

    import osimport numpy as npimport matplotlib.pyplot as pltfrom skimage import color,data,transform,i ...

  7. 吴裕雄 python 神经网络——TensorFlow 循环神经网络处理MNIST手写数字数据集

    #加载TF并导入数据集 import tensorflow as tf from tensorflow.contrib import rnn from tensorflow.examples.tuto ...

  8. 吴裕雄 python 神经网络——TensorFlow 使用卷积神经网络训练和预测MNIST手写数据集

    import tensorflow as tf import numpy as np from tensorflow.examples.tutorials.mnist import input_dat ...

  9. 吴裕雄 python 神经网络——TensorFlow 训练过程的可视化 TensorBoard的应用

    #训练过程的可视化 ,TensorBoard的应用 #导入模块并下载数据集 import tensorflow as tf from tensorflow.examples.tutorials.mni ...

  10. 吴裕雄 python 神经网络——TensorFlow实现搭建基础神经网络

    import numpy as np import tensorflow as tf import matplotlib.pyplot as plt def add_layer(inputs, in_ ...

随机推荐

  1. KVM总结-KVM性能优化之磁盘IO优化

    前面讲了KVM CPU(http://blog.csdn.net/dylloveyou/article/details/71169463).内存(http://blog.csdn.net/dyllov ...

  2. python第三方库,你要的这里都有

    Python的第三方库多的超出我的想象. python 第三方模块 转 https://github.com/masterpy/zwpy_lst   Chardet,字符编码探测器,可以自动检测文本. ...

  3. Java 文件类 File

    1.File 类 1.File 类 1.1.构造方法 文件的 抽象路径名(操作系统无关) 构造方法 格式 说明 File(String filename) 把文件路径名字符串转换为“抽象路径名”,用来 ...

  4. 安装Anaconda3进行python版本管理

    1.下载Anaconda3,我选择了python3的64位版本 2.windows安装,选择加入了系统目录 3.进入命令行进行版本安装 // 安装一个指定版本conda create --name p ...

  5. 常见天气api

    1. 心知天气API1.1 免费版:400次/小时,也就是9600次/天.国内城市数据,天气实况,3天预报,6项生活指数.这个API的免费版已经提供了很多年了,应该算最长寿稳定的那批API了……1.2 ...

  6. 432 4.3.2 STOREDRV.Deliver; recipient thread limit exceeded

    最近几天Hub-Mailbox服务器时不时就CPU超过90%.在任务管理器里面看到edgetransport占用大量CPU.进入EMC的队列查看器,看到邮箱数据库堵塞,队列上万. 堵塞的邮件大多是收件 ...

  7. 31.用 CSS 的动画原理,创作一个乒乓球对打动画

    原文地址:https://segmentfault.com/a/1190000015002553 感想:纯属动画 HTML代码: <div class="court"> ...

  8. vue-i18n

    安装 npm install vue-i18n 初始化 import VueI18n from 'vue-i18n' Vue.use(VueI18n) const messages = { zh: { ...

  9. leetcode1003

    class Solution: def isValid(self, S: str) -> bool: n = len(S) if n % 3 != 0: return False while n ...

  10. win10 死机 无响应

    win10 死机 无响应 用着用着无响应,结束任务出不来,ctrl+alt+delete  无效. 点 窗口的关闭关闭不了. 鼠标键盘无响应. 写的代码变成乱码,影响太严重了,损失惨重. 紧急启动 c ...