from tensorflow.python.keras.preprocessing.image import load_img,img_to_array
from tensorflow.python.keras.models import Sequential,Model
from tensorflow.python.keras.layers import Dense,Flatten,Input
import tensorflow as tf
from tensorflow.python.keras.losses import sparse_categorical_crossentropy
from tensorflow.python import keras
import os
import numpy as np class SingleNN(object): #建立神经网络模型
model = keras.Sequential([
keras.layers.Flatten(input_shape=(28,28)),
keras.layers.Dense(128,activation=tf.nn.relu),
keras.layers.Dense(10,activation=tf.nn.softmax)
]) def __init__(self):
(self.x_train,self.y_train),(self.x_test,self.y_test) = keras.datasets.fashion_mnist.load_data()
#归一化
self.x_train = self.x_train/255.0
self.x_test = self.x_test/255.0 def singlenn_compile(self):
'''
编译模型优化器、损失、准确率
:return:
'''
SingleNN.model.compile(
optimizer=keras.optimizers.SGD(lr=0.01),
loss=keras.losses.sparse_categorical_crossentropy,
metrics=['accuracy']
) def singlenn_fit(self):
"""
进行fit训练
:return:
"""
# modelcheck = keras.callbacks.ModelCheckpoint("./ckpt/singlenn_{epoch:02d}-{acc:.2f}.h5",
# # monitor="val_acc", #保存损失还是准确率
# # save_best_only=True,
# save_weights_only=True,
# mode = 'auto',
# period = 1
# )
board = keras.callbacks.TensorBoard(log_dir="./graph",write_graph=True)
SingleNN.model.fit(self.x_train,self.y_train,epochs=5,callbacks=[board]) def single_evalute(self):
'''
模型评估
:return:
'''
test_loss,test_acc = SingleNN.model.evaluate(self.x_test,self.y_test)
print(test_loss,test_acc) def single_predict(self):
'''
预测结果
:return:
'''
# if os.path.exists("./ckpt/checkpoink"):
# SingleNN.model.load_weights("./ckpt/SingleNN") if os.path.exists("./ckpt/SingleNN.h5"):
SingleNN.model.load_weights("./ckpt/SingleNN.h5") predictions = SingleNN.model.predict(self.x_test) return predictions if __name__ == '__main__':
snn = SingleNN()
snn.singlenn_compile()
snn.singlenn_fit()
snn.single_evalute()
# # SingleNN.model.save_weights("./ckpt/SingleNN")
# SingleNN.model.save_weights("./ckpt/SingleNN.h5")
# predictions = snn.single_predict()
# print(predictions)
# result = np.argmax(predictions,axis=1)
# print(result)

  

TensorFlow-keras fit的callbacks参数,定值保存模型的更多相关文章

  1. TensorFlow笔记四:从生成和保存模型 -> 调用使用模型

    TensorFlow常用的示例一般都是生成模型和测试模型写在一起,每次更换测试数据都要重新训练,过于麻烦, 以下采用先生成并保存本地模型,然后后续程序调用测试. 示例一:线性回归预测 make.py ...

  2. Keras(一)Sequential与Model模型、Keras基本结构功能

    keras介绍与基本的模型保存 思维导图 1.keras网络结构 2.keras网络配置 3.keras预处理功能 模型的节点信息提取 config = model.get_config() 把mod ...

  3. sklearn保存模型-【老鱼学sklearn】

    训练好了一个Model 以后总需要保存和再次预测, 所以保存和读取我们的sklearn model也是同样重要的一步. 比如,我们根据房源样本数据训练了一下房价模型,当用户输入自己的房子后,我们就需要 ...

  4. 转sklearn保存模型

    训练好了一个Model 以后总需要保存和再次预测, 所以保存和读取我们的sklearn model也是同样重要的一步. 比如,我们根据房源样本数据训练了一下房价模型,当用户输入自己的房子后,我们就需要 ...

  5. [TensorFlow 2] [Keras] fit()、fit_generator() 和 train_on_batch() 分析与应用

    前言 是的,除了水报错文,我也来写点其他的.本文主要介绍Keras中以下三个函数的用法: fit()fit_generator()train_on_batch()当然,与上述三个函数相似的evalua ...

  6. TensorFlow 训练好模型参数的保存和恢复代码

    TensorFlow 训练好模型参数的保存和恢复代码,之前就在想模型不应该每次要个结果都要重新训练一遍吧,应该训练一次就可以一直使用吧. TensorFlow 提供了 Saver 类,可以进行保存和恢 ...

  7. Deep Learning 32: 自己写的keras的一个callbacks函数,解决keras中不能在每个epoch实时显示学习速率learning rate的问题

    一.问题: keras中不能在每个epoch实时显示学习速率learning rate,从而方便调试,实际上也是为了调试解决这个问题:Deep Learning 31: 不同版本的keras,对同样的 ...

  8. 100天搞定机器学习|day40-42 Tensorflow Keras识别猫狗

    100天搞定机器学习|1-38天 100天搞定机器学习|day39 Tensorflow Keras手写数字识别 前文我们用keras的Sequential 模型实现mnist手写数字识别,准确率0. ...

  9. Keras框架下的保存模型和加载模型

    在Keras框架下训练深度学习模型时,一般思路是在训练环境下训练出模型,然后拿训练好的模型(即保存模型相应信息的文件)到生产环境下去部署.在训练过程中我们可能会遇到以下情况: 需要运行很长时间的程序在 ...

随机推荐

  1. UVA11987 Almost Union-Find 并查集的节点删除

    题意: 第一行给出一个n,m,表示 n个集合,里面的元素为1~n,下面有m种操作,第一个数为 1 时,输入a,b 表示a,b 两个集合合并到一起,第一个数为 2 时,输入a,b表示将 a 从他原来的集 ...

  2. 下面总结一些在HTML中经常使用到的快捷键

    使用的编辑器是VS code: 首先是很基础的: ctrl+s  :保存: ctrl+a  :  全选: ctrl+c , ctrl+c , ctrl+v : 剪切,复制,粘贴: ctrl+z ,ct ...

  3. Springboot使用自定义注解实现简单参数加密解密(注解+HandlerMethodArgumentResolver)

    前言 我黄汉三又回来了,快半年没更新博客了,这半年来的经历实属不易,疫情当头,本人实习的公司没有跟员工共患难, 直接辞掉了很多人.作为一个实习生,本人也被无情开除了.所以本人又得重新准备找工作了. 算 ...

  4. flask中 多对多的关系 主从表之间的的增删改查

    # 角色表模型class Role(db.Model): r_id = db.Column(db.Integer, primary_key=True) r_name = db.Column(db.St ...

  5. Ubuntu系统在Anaconda中安装Python3.6的虚拟环境

    原因:Anaconda的python版本是3.7的,TensorFlow尚不支持此版本,于是我们创建一个Python的虚拟环境以支持TensorFlow 创建tf环境 conda create --n ...

  6. Vue 里面对树状数组进行增删改查 的方法

    [{"id":"5e4c3b02fc984961a17607c37712eae0", "optLock":0, "parentId ...

  7. XDebug的配置和使用

    简介 XDebug是一个开放源代码的PHP程序调试器(即一个Debug工具) 可以用来跟踪,调试和分析PHP程序的运行状况 功能强大的神器,对审计有非常大的帮助. 官网:http://www.xdeb ...

  8. F - F HDU - 1173(二维化一维-思维)

    F - F HDU - 1173 一个邮递员每次只能从邮局拿走一封信送信.在一个二维的直角坐标系中,邮递员只能朝四个方向移动,正北.正东.正南.正西. 有n个需要收信的地址,现在需要你帮助找到一个地方 ...

  9. 我是如何从通信转到Java软件开发工程师的?

    我的读者里面有绝大部分都是在校学生,有本科的,也有专科的,我在微信里收到很多读者的提问,大部分问题都跟如何学习编程有关,有换专业自学的.有迷茫不知道如何学习的.有报培训班没啥效果的等等,我能感受到他们 ...

  10. Pytest系列(7) - skip、skipif跳过用例

    如果你还想从头学起Pytest,可以看看这个系列的文章哦! https://www.cnblogs.com/poloyy/category/1690628.html 前言 pytest.mark.sk ...