TensorFlow-keras fit的callbacks参数,定值保存模型
from tensorflow.python.keras.preprocessing.image import load_img,img_to_array
from tensorflow.python.keras.models import Sequential,Model
from tensorflow.python.keras.layers import Dense,Flatten,Input
import tensorflow as tf
from tensorflow.python.keras.losses import sparse_categorical_crossentropy
from tensorflow.python import keras
import os
import numpy as np class SingleNN(object): #建立神经网络模型
model = keras.Sequential([
keras.layers.Flatten(input_shape=(28,28)),
keras.layers.Dense(128,activation=tf.nn.relu),
keras.layers.Dense(10,activation=tf.nn.softmax)
]) def __init__(self):
(self.x_train,self.y_train),(self.x_test,self.y_test) = keras.datasets.fashion_mnist.load_data()
#归一化
self.x_train = self.x_train/255.0
self.x_test = self.x_test/255.0 def singlenn_compile(self):
'''
编译模型优化器、损失、准确率
:return:
'''
SingleNN.model.compile(
optimizer=keras.optimizers.SGD(lr=0.01),
loss=keras.losses.sparse_categorical_crossentropy,
metrics=['accuracy']
) def singlenn_fit(self):
"""
进行fit训练
:return:
"""
# modelcheck = keras.callbacks.ModelCheckpoint("./ckpt/singlenn_{epoch:02d}-{acc:.2f}.h5",
# # monitor="val_acc", #保存损失还是准确率
# # save_best_only=True,
# save_weights_only=True,
# mode = 'auto',
# period = 1
# )
board = keras.callbacks.TensorBoard(log_dir="./graph",write_graph=True)
SingleNN.model.fit(self.x_train,self.y_train,epochs=5,callbacks=[board]) def single_evalute(self):
'''
模型评估
:return:
'''
test_loss,test_acc = SingleNN.model.evaluate(self.x_test,self.y_test)
print(test_loss,test_acc) def single_predict(self):
'''
预测结果
:return:
'''
# if os.path.exists("./ckpt/checkpoink"):
# SingleNN.model.load_weights("./ckpt/SingleNN") if os.path.exists("./ckpt/SingleNN.h5"):
SingleNN.model.load_weights("./ckpt/SingleNN.h5") predictions = SingleNN.model.predict(self.x_test) return predictions if __name__ == '__main__':
snn = SingleNN()
snn.singlenn_compile()
snn.singlenn_fit()
snn.single_evalute()
# # SingleNN.model.save_weights("./ckpt/SingleNN")
# SingleNN.model.save_weights("./ckpt/SingleNN.h5")
# predictions = snn.single_predict()
# print(predictions)
# result = np.argmax(predictions,axis=1)
# print(result)
TensorFlow-keras fit的callbacks参数,定值保存模型的更多相关文章
- TensorFlow笔记四:从生成和保存模型 -> 调用使用模型
TensorFlow常用的示例一般都是生成模型和测试模型写在一起,每次更换测试数据都要重新训练,过于麻烦, 以下采用先生成并保存本地模型,然后后续程序调用测试. 示例一:线性回归预测 make.py ...
- Keras(一)Sequential与Model模型、Keras基本结构功能
keras介绍与基本的模型保存 思维导图 1.keras网络结构 2.keras网络配置 3.keras预处理功能 模型的节点信息提取 config = model.get_config() 把mod ...
- sklearn保存模型-【老鱼学sklearn】
训练好了一个Model 以后总需要保存和再次预测, 所以保存和读取我们的sklearn model也是同样重要的一步. 比如,我们根据房源样本数据训练了一下房价模型,当用户输入自己的房子后,我们就需要 ...
- 转sklearn保存模型
训练好了一个Model 以后总需要保存和再次预测, 所以保存和读取我们的sklearn model也是同样重要的一步. 比如,我们根据房源样本数据训练了一下房价模型,当用户输入自己的房子后,我们就需要 ...
- [TensorFlow 2] [Keras] fit()、fit_generator() 和 train_on_batch() 分析与应用
前言 是的,除了水报错文,我也来写点其他的.本文主要介绍Keras中以下三个函数的用法: fit()fit_generator()train_on_batch()当然,与上述三个函数相似的evalua ...
- TensorFlow 训练好模型参数的保存和恢复代码
TensorFlow 训练好模型参数的保存和恢复代码,之前就在想模型不应该每次要个结果都要重新训练一遍吧,应该训练一次就可以一直使用吧. TensorFlow 提供了 Saver 类,可以进行保存和恢 ...
- Deep Learning 32: 自己写的keras的一个callbacks函数,解决keras中不能在每个epoch实时显示学习速率learning rate的问题
一.问题: keras中不能在每个epoch实时显示学习速率learning rate,从而方便调试,实际上也是为了调试解决这个问题:Deep Learning 31: 不同版本的keras,对同样的 ...
- 100天搞定机器学习|day40-42 Tensorflow Keras识别猫狗
100天搞定机器学习|1-38天 100天搞定机器学习|day39 Tensorflow Keras手写数字识别 前文我们用keras的Sequential 模型实现mnist手写数字识别,准确率0. ...
- Keras框架下的保存模型和加载模型
在Keras框架下训练深度学习模型时,一般思路是在训练环境下训练出模型,然后拿训练好的模型(即保存模型相应信息的文件)到生产环境下去部署.在训练过程中我们可能会遇到以下情况: 需要运行很长时间的程序在 ...
随机推荐
- 记录一个不同的流媒体网站实现方法,和用Python爬虫爬它的坑
今天找到一片电影,想把它下载下来. 先开Networks工具分析一下: 初步分析发现,视频加载时会拉取TS格式的文件,推测这是一个m3u8的索引,记录着几百段TS文件,这样方便快进时加载. 但是实际分 ...
- Java内存模型和ConcurrentHashMap 1.7源码分析
简介 ConcurrentHashMap 是 util.concurrent 包的重要成员.本文将结合 Java 内存模型,分析 JDK 源代码,探索 ConcurrentHashMap 高并发的具体 ...
- HashCode()与equals()深入理解
1.hashCode()和equals()方法都是Object类提供的方法, hashCode()返回该对象的哈希码值,该值通常是一个由该对象的内部地址转换而来的int型整数, Object的equa ...
- 【每周小项目】使用 puppeteer 插件爬取动态网站
目录 0. 前言 问题 解决 1. 下载与引包 2. 使用步骤 3. 爬过的几个坑 page.evaluate 的传参问题 元素操作问题 0. 前言 这两天对爬虫开始感兴趣,最开始是源于天涯的一个房价 ...
- idea打包或编译错误,错误为c盘idea路径某些文件被占用(非idea文件,项目生成的文件)
方法列表(2的效果可能更好) 1.将被占用的文件删除之后,重新打包或编译. 2.多编译几次项目. 3.发现真正可能的原因.(貌似被南航企业版360拦截了,导致targe或maven等文件被占用问题) ...
- C 旅店
时间限制 : - MS 空间限制 : - KB 评测说明 : 1s,256m 问题描述 一条笔直的公路旁有N家旅店,从左往右编号1到N,其中第i家旅店的位置坐标为Xi.旅人何老板总在赶路.他白天 ...
- H - 覆盖的面积(线段树-线段扫描 + 离散化(板题))
给定平面上若干矩形,求出被这些矩形覆盖过至少两次的区域的面积. Input 输入数据的第一行是一个正整数T(1<=T<=100),代表测试数据的数量.每个测试数据的第一行是一个正整数N(1 ...
- override 重写
//override:子类继承父类,子类重写父类的方法 public class override { public static void main(String[] args) { horse h ...
- Redis 设计与实现笔记 - SDS
Redis 中的字符串没有使用 C语言中的字符指针(char *),而是使用了自定义的结构 sds. 文件: sds.h sds.c 结构: struct sdshdr { int len; // 填 ...
- docker下安装centos,并在其上搭建lnmp环境
一.安装CentOs容器 1.进入docker下载CentOs,这里我使用的CentOs6.8 docker pull centos:6.8 2.创建容器 sudo docker run --priv ...