# -*- coding=utf-8 -*-
import numpy as np
import keras
from keras.models import Sequential
from keras.layers import Dense,Flatten,Dropout
from keras.optimizers import Adadelta
from keras.datasets import cifar10
from keras import applications

import matplotlib.pyplot as plt
%matplotlib inline

vgg_model=applications.VGG19(include_top=False,weights='imagenet')
vgg_model.summary()

(train_x,train_y),(test_x,test_y)=cifar10.load_data()
print(train_x.shape,train_y.shape,test_x.shape,test_y.shape)

n_classes=10
train_y=keras.utils.to_categorical(train_y,n_classes)
test_y=keras.utils.to_categorical(test_y,n_classes)

bottleneck_feature_train=vgg_model.predict(train_x,verbose=1)
bottleneck_feature_test=vgg_model.predict(test_x,verbose=1)

print(bottleneck_feature_train.shape,bottleneck_feature_test.shape)

my_model=Sequential()
my_model.add(Flatten())###my_model.add(Flatten(input_shape=?))
my_model.add(Dense(512,activation='relu'))
my_model.add(Dropout(0.5))
my_model.add(Dense(256,activation='relu'))
my_model.add(Dropout(0.5))
my_model.add(Dense(n_classes,activation='softmax'))
my_model.compile(optimizer=Adadelta(),loss="categorical_crossentropy",\
metrics=['accuracy'])
my_model.fit(bottleneck_feature_train,train_y,batch_size=128,epochs=50,verbose=1)

evaluation=my_model.evaluate(bottleneck_feature_test,test_y,batch_size=128,verbose=0)
print("loss:",evaluation[0],"accuracy:",evaluation[1])

def predict_label(img_idx,show_proba=True):
plt.imshow(train_x[img_idx],aspect='auto')
plt.title("Image to be labeled")
plt.show()
img_4D=(bottleneck_feature_train[img_idx])[np.newaxis,:,:,:]
prediction=my_model.predict_classes(img_4D,batch_size=1,verbose=0)
print("Actual class:{0}\nPredict class:{1}".format(np.argmax(train_y[img_idx],0),prediction))

if show_proba:
pred=my_model.predict_proba(img_4D,batch_size=1,verbose=0)
print(pred)

for i in range(3):
predict_label(i,show_proba=True)

吴裕雄 python神经网络(8)的更多相关文章

  1. 吴裕雄 python神经网络 花朵图片识别(10)

    import osimport numpy as npimport matplotlib.pyplot as pltfrom PIL import Image, ImageChopsfrom skim ...

  2. 吴裕雄 python神经网络 花朵图片识别(9)

    import osimport numpy as npimport matplotlib.pyplot as pltfrom PIL import Image, ImageChopsfrom skim ...

  3. 吴裕雄 python神经网络 手写数字图片识别(5)

    import kerasimport matplotlib.pyplot as pltfrom keras.models import Sequentialfrom keras.layers impo ...

  4. 吴裕雄 python神经网络 水果图片识别(4)

    # coding: utf-8 # In[1]:import osimport numpy as npfrom skimage import color, data, transform, io # ...

  5. 吴裕雄 python神经网络 水果图片识别(3)

    import osimport kerasimport timeimport numpy as npimport tensorflow as tffrom random import shufflef ...

  6. 吴裕雄 python神经网络 水果图片识别(2)

    import osimport numpy as npimport matplotlib.pyplot as pltfrom skimage import color,data,transform,i ...

  7. 吴裕雄 python 神经网络——TensorFlow 循环神经网络处理MNIST手写数字数据集

    #加载TF并导入数据集 import tensorflow as tf from tensorflow.contrib import rnn from tensorflow.examples.tuto ...

  8. 吴裕雄 python 神经网络——TensorFlow 使用卷积神经网络训练和预测MNIST手写数据集

    import tensorflow as tf import numpy as np from tensorflow.examples.tutorials.mnist import input_dat ...

  9. 吴裕雄 python 神经网络——TensorFlow 训练过程的可视化 TensorBoard的应用

    #训练过程的可视化 ,TensorBoard的应用 #导入模块并下载数据集 import tensorflow as tf from tensorflow.examples.tutorials.mni ...

  10. 吴裕雄 python 神经网络——TensorFlow实现搭建基础神经网络

    import numpy as np import tensorflow as tf import matplotlib.pyplot as plt def add_layer(inputs, in_ ...

随机推荐

  1. django-response对象

    HttpResponse 对象则需要 web 开发者自己创建,一般在视图函数中 return 回去.下面我们就来看看 HttpResponse 对象的各种细节 首先,这个对象由 HttpRespons ...

  2. shell脚本遍历子目录

    #!/bin/bashsource /etc/profile tool_path=/data/rsync_clientroot_path=/data/log ####yyyy-mm-dd¸ñʽdat ...

  3. 让MySql支持表情符号(MySQL中4字节utf8字符保存方法)

    UTF-8编码有可能是两个.三个.四个字节.Emoji表情是4个字节,而MySQL的utf8编码最多3个字节,所以数据插不进去. 解决方案:将编码从utf8转换成utf8mb4. 1. 修改my.in ...

  4. MyBatis Generator 生成数据库自带中文注释

    1. maven依赖 <!-- mybatis生成 jar包 --> <dependency> <groupId>org.mybatis.generator< ...

  5. 详解CSS3属性前缀(转)

    原文地址 CSS3的属性为什么要带前缀 使用过CSS3属性的同学都知道,CSS3属性都需要带各浏览器的前缀,甚至到现在,依然还有很多属性需要带前缀.这是为什么呢? 我的理解是,浏览器厂商以前就一直在实 ...

  6. Shell 格式化输出数字、字符串(printf)

    1.语法 printf打印格式字符串,解释'%'指令和'\'转义. 1.1.转义 printf使用时需要指定输出格式,输出后不换行. printf FORMAT [ARGUMENT] printf O ...

  7. django-权限验证场景

    1.需要登录才能够访问的验证 from django.contrib.auth.decorators import login_required # 登录装饰器 # method_decorator ...

  8. gentoo openrc 开机打印信息

    gentoo openrc 开机的时候,最开始 一些硬件的信息, 后面是一些内核和驱动的信息. 硬件的信息是默认保存到 /var/log/dmesg 中, 可以使用 dmesg | less 来查看硬 ...

  9. redis的键命令

    键的命令 查找键,参数支持正则 KEYS pattern 判断键是否存在,如果存在返回1,不存在返回0 EXISTS key [key ...] 查看键对应的value的类型 TYPE key 删除键 ...

  10. Linux 硬链接、软链接

    索引节点 inode(index node) 我们知道文件都有文件名与数据,这在 Linux 上被分成两个部分:用户数据 (user data) 与元数据 (metadata).用户数据,即文件数据块 ...