导入必要的库:

import os
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets,layers,optimizers,Sequential,metrics
os.environ["TF_CPP_MIN_LOG_LEVEL"]='2'
tf.random.set_seed(2345)

其中os.environ部分是为了减少Tensorflow打印的信息

构建网络结构:

conv_layers=[
layers.Conv2D(64,kernel_size=[3,3],padding="same",activation=tf.nn.relu),
layers.Conv2D(64, kernel_size=[3, 3], padding="same", activation=tf.nn.relu),
layers.MaxPool2D(pool_size=[2,2],strides=2,padding="same"), layers.Conv2D(128, kernel_size=[3, 3], padding="same", activation=tf.nn.relu),
layers.Conv2D(128, kernel_size=[3, 3], padding="same", activation=tf.nn.relu),
layers.MaxPool2D(pool_size=[2, 2], strides=2, padding="same"), layers.Conv2D(256, kernel_size=[3, 3], padding="same", activation=tf.nn.relu),
layers.Conv2D(256, kernel_size=[3, 3], padding="same", activation=tf.nn.relu),
layers.MaxPool2D(pool_size=[2, 2], strides=2, padding="same"), layers.Conv2D(512, kernel_size=[3, 3], padding="same", activation=tf.nn.relu),
layers.Conv2D(512, kernel_size=[3, 3], padding="same", activation=tf.nn.relu),
layers.MaxPool2D(pool_size=[2, 2], strides=2, padding="same"), layers.Conv2D(512, kernel_size=[3, 3], padding="same", activation=tf.nn.relu),
layers.Conv2D(512, kernel_size=[3, 3], padding="same", activation=tf.nn.relu),
layers.MaxPool2D(pool_size=[2, 2], strides=2, padding="same"),
]

优化器:

def preprocess(x,y):
x=tf.cast(x,dtype=tf.float32)/255.
y=tf.cast(y,dtype=tf.int32)
return x,y

加载数据:

这里使用比较常见的CIFAR10的数据集

(x_train,y_train),(x_test,y_test)=datasets.cifar10.load_data()
y_train=tf.squeeze(y_train,axis=1)
y_test=tf.squeeze(y_test,axis=1)
# print(x_train.shape,y_train.shape,x_test.shape,y_test.shape)
train_data=tf.data.Dataset.from_tensor_slices((x_train,y_train))
train_data=train_data.shuffle(1000).map(preprocess).batch(64) test_data=tf.data.Dataset.from_tensor_slices((x_test,y_test))
test_data=test_data.map(preprocess).batch(64) sample=next(iter(train_data))
print('sample:',sample[0].shape,sample[1].shape,
tf.reduce_min(sample[0]),tf.reduce_max(sample[0]))

sample=next(iter(train_data))

这一部分是打印train_data的信息

完善网络:

def main():
conv_net=Sequential(conv_layers)
# x=tf.random.normal([4,32,32,3])
# out=conv_net(x)
# print(out.shape)
fc_net=Sequential([
layers.Dense(256,activation=tf.nn.relu),
layers.Dense(128,activation=tf.nn.relu),
layers.Dense(10,activation=None),
])
conv_net.build(input_shape=[None, 32, 32, 3])
fc_net.build(input_shape=[None,512])
optimizer=optimizers.Adam(lr=1e-4)

计算loss:

variables=conv_net.trainable_variables+fc_net.trainable_variables
for epoch in range(50):
for step,(x,y) in enumerate(train_data):
with tf.GradientTape() as tape:
out=conv_net(x)
out=tf.reshape(out,[-1,512])
logits=fc_net(out)
y_onehot=tf.one_hot(y,depth=10)
loss=tf.losses.categorical_crossentropy(y_onehot,logits,from_logits=True)
loss=tf.reduce_mean(loss)
grads=tape.gradient(loss,variables)
optimizer.apply_gradients(zip(grads,variables))
if step%100==0:
print(epoch,step,'loss',float(loss))

测试:

total_num=0
total_correct=0
for x,y in test_data:
out=conv_net(x)
out=tf.reshape(out,[-1,512])
logits=fc_net(out)
prob=tf.nn.softmax(logits,axis=1)
pred=tf.argmax(prob,axis=1)
pred=tf.cast(pred,dtype=tf.int32)
correct=tf.cast(tf.equal(pred,y),dtype=tf.int32)
correct=tf.reduce_sum(correct)
total_num+=x.shape[0]
total_correct+=int(correct)
acc=total_correct/total_num
print(epoch,'acc:',acc)
if __name__ == '__main__':
main()

训练数据:

0 0 loss 2.302990436553955
0 100 loss 1.9521405696868896
0 200 loss 1.9435423612594604
0 300 loss 1.6067744493484497
0 400 loss 1.5959546566009521
0 500 loss 1.734712839126587
0 600 loss 1.2384529113769531
0 700 loss 1.3307044506072998
0 acc: 0.4787
5 0 loss 0.6936513185501099
5 100 loss 0.7874761819839478
5 200 loss 0.7884306907653809
5 300 loss 0.6663026809692383
5 400 loss 0.4075947105884552
5 500 loss 0.6752095222473145
5 600 loss 0.5246847867965698
5 700 loss 0.5275574922561646
5 acc: 0.7299
10 0 loss 0.7874808311462402
10 100 loss 0.5072851181030273
10 200 loss 0.4451877772808075
10 300 loss 0.177499920129776
10 400 loss 0.13723205029964447
10 500 loss 0.2971668243408203
10 600 loss 0.25279730558395386
10 700 loss 0.36453887820243835
10 acc: 0.7355
15 0 loss 0.2800075113773346
15 100 loss 0.1841358095407486
15 200 loss 0.040746696293354034
15 300 loss 0.06615383923053741
15 400 loss 0.1183178648352623
15 500 loss 0.07481158524751663
15 600 loss 0.09398414194583893
15 700 loss 0.03665520250797272
15 acc: 0.7469
20 0 loss 0.02290465496480465
20 100 loss 0.008633529767394066
20 200 loss 0.21534058451652527
20 300 loss 0.011568240821361542
20 400 loss 0.08179830759763718
20 500 loss 0.02673691138625145
20 600 loss 0.06506452709436417
20 700 loss 0.026200752705335617
20 acc: 0.7621

训练大概50epoch,这里仅仅展示20个,可以看到,验证准确率是在不断的上升的,后面的数据就不展示了,我也没训练完,有兴趣的可以接着跑将模型保存一下,有时间再接着训练

Tensorflow2.0实现VGG13的更多相关文章

  1. 基于tensorflow2.0 使用tf.keras实现Fashion MNIST

    本次使用的是2.0测试版,正式版估计会很快就上线了 tf2好像更新了蛮多东西 虽然教程不多 还是找了个试试 的确简单不少,但是还是比较喜欢现在这种写法 老样子先导入库 import tensorflo ...

  2. Google工程师亲授 Tensorflow2.0-入门到进阶

    第1章 Tensorfow简介与环境搭建 本门课程的入门章节,简要介绍了tensorflow是什么,详细介绍了Tensorflow历史版本变迁以及tensorflow的架构和强大特性.并在Tensor ...

  3. TensorFlow2.0(1):基本数据结构—张量

    1 引言 TensorFlow2.0版本已经发布,虽然不是正式版,但预览版都发布了,正式版还会远吗?相比于1.X,2.0版的TensorFlow修改的不是一点半点,这些修改极大的弥补了1.X版本的反人 ...

  4. 『TensorFlow2.0正式版教程』极简安装TF2.0正式版(CPU&GPU)教程

    0 前言 TensorFlow 2.0,今天凌晨,正式放出了2.0版本. 不少网友表示,TensorFlow 2.0比PyTorch更好用,已经准备全面转向这个新升级的深度学习框架了. ​ 本篇文章就 ...

  5. 『TensorFlow2.0正式版』TF2.0+Keras速成教程·零:开篇简介与环境准备

    此篇教程参考自TensorFlow 2.0 + Keras Crash Course,在原文的基础上进行了适当的总结与改编,以适应于国内开发者的理解与使用,水平有限,如果写的不对的地方欢迎大家评论指出 ...

  6. TensorFlow2.0(9):TensorBoard可视化

    .caret, .dropup > .btn > .caret { border-top-color: #000 !important; } .label { border: 1px so ...

  7. TensorFlow2.0(11):tf.keras建模三部曲

    .caret, .dropup > .btn > .caret { border-top-color: #000 !important; } .label { border: 1px so ...

  8. tensorflow2.0安装

    版本: python3.5 Anaconda 4.2.0 tensorflow2.0 cpu版本 1.安装命令 pip3 install tensorflow==2.0.0.0a0 -i https: ...

  9. TensorFlow2.0初体验

    TF2.0默认为动态图,即eager模式.意味着TF能像Pytorch一样不用在session中才能输出中间参数值了,那么动态图和静态图毕竟是有区别的,tf2.0也会有写法上的变化.不过值得吐槽的是, ...

  10. tensorflow2.0 学习(三)

    用tensorflow2.0 版回顾了一下mnist的学习 代码如下,感觉这个版本下的mnist学习更简洁,更方便 关于tensorflow的基础知识,这里就不更新了,用到什么就到网上取搜索相关的知识 ...

随机推荐

  1. 你能看到这个汉字么“  ” ?关于Unicode的私人使用区(PUA) 和浏览器端显示处理

    如果你现在使用的是chrome查看那么你是看不到我标题中的汉字的,显示为一个小方框,但是你使用edge查看的话,这个字就能正常的显示出来,不信你试试! 本故事源于我在做数据过程中遇到Unicode编码 ...

  2. Nextcloud 维护管理

    Nextcloud 维护管理 目录 Nextcloud 维护管理 1.管理员被禁用怎么办 2.管理员密码忘了怎么办 1.管理员被禁用怎么办 通过命令行解禁管理员用户: 方法一:通过命令行解禁管理员用户 ...

  3. vue3探索——pinia高阶使用

    以下是一些 Pinia 的其他高阶功能: storeToRefs():响应式解构仓库,保证解构出来的数据是响应式的数据. 状态持久化:Pinia 并没有内置的状态持久化功能,但你可以使用第三方库或自定 ...

  4. Java 21 新特性:switch的模式匹配

    在之前的Java 17新特性中,我们介绍过关于JEP 406: switch的模式匹配,但当时还只是关于此内容的首个预览版本.之后在JDK 18.JDK 19.JDK 20中又都进行了更新和完善.如今 ...

  5. Solution -「CF 724F」Uniformly Branched Trees

    Description Link. 给定三个数 \(n,d,mod\),求有多少种 \(n\) 个点的不同构的树满足:除了度数为 \(1\) 的结点外,其余结点的度数均为 \(d\).答案对质数 \( ...

  6. 记一次 .NET 某仪器测量系统 CPU爆高分析

    一:背景 1. 讲故事 最近也挺奇怪,看到了两起 CPU 爆高的案例,且诱因也是一致的,觉得有一些代表性,合并分享出来帮助大家来避坑吧,闲话不多说,直接上 windbg 分析. 二:WinDbg 分析 ...

  7. es针对nested类型数据无法进行过滤查询的问题记录

    问题描述 es中存在有一个名为task_data_1的索引,其字段映射关系如下所示: { "task_data_1" : { "mappings" : { &q ...

  8. Cython加密python代码防止反编译

    本方法适用于Linux环境下: 1.安装库Cython pip3 install Cython==3.0.0a10 2.编写待加密文件:hello.py import random def ac(): ...

  9. 手把手教你写一个JSON在线解析的前端网站1

    前言 作为一名Android开发,经常要跟后端同事联调接口,那么总避免不了要格式化接口返回值,将其转换为清晰直观高亮的UI样式以及折叠部分内容,方便我们查看定位关键的信息. 一直以来都是打开Googl ...

  10. 铅华洗尽,粉黛不施,人工智能AI基于ProPainter技术去除图片以及视频水印(Python3.10)

    视频以及图片修复技术是一项具有挑战性的AI视觉任务,它涉及在视频或者图片序列中填补缺失或损坏的区域,同时保持空间和时间的连贯性.该技术在视频补全.对象移除.视频恢复等领域有广泛应用.近年来,两种突出的 ...