# -*- coding=utf-8 -*-
import numpy as np
import keras
from keras.models import Sequential
from keras.layers import Dense,Flatten,Dropout
from keras.optimizers import Adadelta
from keras.datasets import cifar10
from keras import applications

import matplotlib.pyplot as plt
%matplotlib inline

vgg_model=applications.VGG19(include_top=False,weights='imagenet')
vgg_model.summary()

(train_x,train_y),(test_x,test_y)=cifar10.load_data()
print(train_x.shape,train_y.shape,test_x.shape,test_y.shape)

n_classes=10
train_y=keras.utils.to_categorical(train_y,n_classes)
test_y=keras.utils.to_categorical(test_y,n_classes)

bottleneck_feature_train=vgg_model.predict(train_x,verbose=1)
bottleneck_feature_test=vgg_model.predict(test_x,verbose=1)

print(bottleneck_feature_train.shape,bottleneck_feature_test.shape)

my_model=Sequential()
my_model.add(Flatten())###my_model.add(Flatten(input_shape=?))
my_model.add(Dense(512,activation='relu'))
my_model.add(Dropout(0.5))
my_model.add(Dense(256,activation='relu'))
my_model.add(Dropout(0.5))
my_model.add(Dense(n_classes,activation='softmax'))
my_model.compile(optimizer=Adadelta(),loss="categorical_crossentropy",\
metrics=['accuracy'])
my_model.fit(bottleneck_feature_train,train_y,batch_size=128,epochs=50,verbose=1)

evaluation=my_model.evaluate(bottleneck_feature_test,test_y,batch_size=128,verbose=0)
print("loss:",evaluation[0],"accuracy:",evaluation[1])

def predict_label(img_idx,show_proba=True):
plt.imshow(train_x[img_idx],aspect='auto')
plt.title("Image to be labeled")
plt.show()
img_4D=(bottleneck_feature_train[img_idx])[np.newaxis,:,:,:]
prediction=my_model.predict_classes(img_4D,batch_size=1,verbose=0)
print("Actual class:{0}\nPredict class:{1}".format(np.argmax(train_y[img_idx],0),prediction))

if show_proba:
pred=my_model.predict_proba(img_4D,batch_size=1,verbose=0)
print(pred)

for i in range(3):
predict_label(i,show_proba=True)

吴裕雄 python神经网络(8)的更多相关文章

  1. 吴裕雄 python神经网络 花朵图片识别(10)

    import osimport numpy as npimport matplotlib.pyplot as pltfrom PIL import Image, ImageChopsfrom skim ...

  2. 吴裕雄 python神经网络 花朵图片识别(9)

    import osimport numpy as npimport matplotlib.pyplot as pltfrom PIL import Image, ImageChopsfrom skim ...

  3. 吴裕雄 python神经网络 手写数字图片识别(5)

    import kerasimport matplotlib.pyplot as pltfrom keras.models import Sequentialfrom keras.layers impo ...

  4. 吴裕雄 python神经网络 水果图片识别(4)

    # coding: utf-8 # In[1]:import osimport numpy as npfrom skimage import color, data, transform, io # ...

  5. 吴裕雄 python神经网络 水果图片识别(3)

    import osimport kerasimport timeimport numpy as npimport tensorflow as tffrom random import shufflef ...

  6. 吴裕雄 python神经网络 水果图片识别(2)

    import osimport numpy as npimport matplotlib.pyplot as pltfrom skimage import color,data,transform,i ...

  7. 吴裕雄 python 神经网络——TensorFlow 循环神经网络处理MNIST手写数字数据集

    #加载TF并导入数据集 import tensorflow as tf from tensorflow.contrib import rnn from tensorflow.examples.tuto ...

  8. 吴裕雄 python 神经网络——TensorFlow 使用卷积神经网络训练和预测MNIST手写数据集

    import tensorflow as tf import numpy as np from tensorflow.examples.tutorials.mnist import input_dat ...

  9. 吴裕雄 python 神经网络——TensorFlow 训练过程的可视化 TensorBoard的应用

    #训练过程的可视化 ,TensorBoard的应用 #导入模块并下载数据集 import tensorflow as tf from tensorflow.examples.tutorials.mni ...

  10. 吴裕雄 python 神经网络——TensorFlow实现搭建基础神经网络

    import numpy as np import tensorflow as tf import matplotlib.pyplot as plt def add_layer(inputs, in_ ...

随机推荐

  1. Fragment重叠,使用show和hide控制显示和隐藏界面重叠问题;

    Fragment重叠原因: App因某种原因崩溃自动重启,或App长期在后台运行Fragment依赖的activity被回收等等原因:当系统内存不足,Fragment 的依附Activity 被回收的 ...

  2. [ORACLE]java.sql.SQLRecoverableException: IO Error: Connection rese

    随机数引起的阻塞问题 程序通过 java -jar -Djava.security.egd=file:/dev/./urandom xxx 的方式执行, http://hongjiang.info/j ...

  3. 初级安全入门——SQL注入的原理与利用

    工具简介 SQLMAP: 一个开放源码的渗透测试工具,它可以自动探测和利用SQL注入漏洞来接管数据库服务器.它配备了一个强大的探测引擎,为最终渗透测试人员提供很多强大的功能,可以拖库,可以访问底层的文 ...

  4. Access 分页

    access分页 pageSize 每页显示多少条数据 pageNumber 页数 从客户端传来 pages) SQL语句 select top pageSize * from 表名 where id ...

  5. C# sqlserver winform

    //public static readonly string LocalSqlServer = System.Configuration.ConfigurationManager.AppSettin ...

  6. spring boot 自定义视图路径

    boot 自定义访问视图路径 . 配置文件 目录结构 启动类: html页面 访问: 覆盖boot默认路径引用. 如果没有重新配置,则在pom引用模板. 修改配置文件. 注意一定要编译工程

  7. 《算法》第四章部分程序 part 10

    ▶ 书中第四章部分程序,包括在加上自己补充的代码,包括无向图连通分量,Kosaraju - Sharir 算法.Tarjan 算法.Gabow 算法计算有向图的强连通分量 ● 无向图连通分量 pack ...

  8. Maven子模块

    1.选取父工程创建子模块(Maven Modeule) 2.创建子模块时 Packaging 选 jar

  9. Nginx配置HTTPS证书网站

    前提: 1.主机需要先安装openssl     2.编译安装nginx时,要加上--with-http_ssl_module  这个ssl模块 现在开始配置:(我当时配置时,主机已安装了openss ...

  10. Linux 创建用户并赋予 Sudo 权限

    01,创建账号 => useradd admin 02,赋予密码 => passwd admin 03,修改 sudo 权限文件,使得该用户可以使用 sudo 命令 vim /etc/su ...