from tensorflow.python.keras.preprocessing.image import load_img,img_to_array
from tensorflow.python.keras.models import Sequential,Model
from tensorflow.python.keras.layers import Dense,Flatten,Input
import tensorflow as tf
from tensorflow.python.keras.losses import sparse_categorical_crossentropy
from tensorflow.python import keras
import os
import numpy as np class SingleNN(object): #建立神经网络模型
model = keras.Sequential([
keras.layers.Flatten(input_shape=(28,28)),
keras.layers.Dense(128,activation=tf.nn.relu),
keras.layers.Dense(10,activation=tf.nn.softmax)
]) def __init__(self):
(self.x_train,self.y_train),(self.x_test,self.y_test) = keras.datasets.fashion_mnist.load_data()
#归一化
self.x_train = self.x_train/255.0
self.x_test = self.x_test/255.0 def singlenn_compile(self):
'''
编译模型优化器、损失、准确率
:return:
'''
SingleNN.model.compile(
optimizer=keras.optimizers.SGD(lr=0.01),
loss=keras.losses.sparse_categorical_crossentropy,
metrics=['accuracy']
) def singlenn_fit(self):
"""
进行fit训练
:return:
"""
# modelcheck = keras.callbacks.ModelCheckpoint("./ckpt/singlenn_{epoch:02d}-{acc:.2f}.h5",
# # monitor="val_acc", #保存损失还是准确率
# # save_best_only=True,
# save_weights_only=True,
# mode = 'auto',
# period = 1
# )
board = keras.callbacks.TensorBoard(log_dir="./graph",write_graph=True)
SingleNN.model.fit(self.x_train,self.y_train,epochs=5,callbacks=[board]) def single_evalute(self):
'''
模型评估
:return:
'''
test_loss,test_acc = SingleNN.model.evaluate(self.x_test,self.y_test)
print(test_loss,test_acc) def single_predict(self):
'''
预测结果
:return:
'''
# if os.path.exists("./ckpt/checkpoink"):
# SingleNN.model.load_weights("./ckpt/SingleNN") if os.path.exists("./ckpt/SingleNN.h5"):
SingleNN.model.load_weights("./ckpt/SingleNN.h5") predictions = SingleNN.model.predict(self.x_test) return predictions if __name__ == '__main__':
snn = SingleNN()
snn.singlenn_compile()
snn.singlenn_fit()
snn.single_evalute()
# # SingleNN.model.save_weights("./ckpt/SingleNN")
# SingleNN.model.save_weights("./ckpt/SingleNN.h5")
# predictions = snn.single_predict()
# print(predictions)
# result = np.argmax(predictions,axis=1)
# print(result)

  

TensorFlow-keras fit的callbacks参数,定值保存模型的更多相关文章

  1. TensorFlow笔记四:从生成和保存模型 -> 调用使用模型

    TensorFlow常用的示例一般都是生成模型和测试模型写在一起,每次更换测试数据都要重新训练,过于麻烦, 以下采用先生成并保存本地模型,然后后续程序调用测试. 示例一:线性回归预测 make.py ...

  2. Keras(一)Sequential与Model模型、Keras基本结构功能

    keras介绍与基本的模型保存 思维导图 1.keras网络结构 2.keras网络配置 3.keras预处理功能 模型的节点信息提取 config = model.get_config() 把mod ...

  3. sklearn保存模型-【老鱼学sklearn】

    训练好了一个Model 以后总需要保存和再次预测, 所以保存和读取我们的sklearn model也是同样重要的一步. 比如,我们根据房源样本数据训练了一下房价模型,当用户输入自己的房子后,我们就需要 ...

  4. 转sklearn保存模型

    训练好了一个Model 以后总需要保存和再次预测, 所以保存和读取我们的sklearn model也是同样重要的一步. 比如,我们根据房源样本数据训练了一下房价模型,当用户输入自己的房子后,我们就需要 ...

  5. [TensorFlow 2] [Keras] fit()、fit_generator() 和 train_on_batch() 分析与应用

    前言 是的,除了水报错文,我也来写点其他的.本文主要介绍Keras中以下三个函数的用法: fit()fit_generator()train_on_batch()当然,与上述三个函数相似的evalua ...

  6. TensorFlow 训练好模型参数的保存和恢复代码

    TensorFlow 训练好模型参数的保存和恢复代码,之前就在想模型不应该每次要个结果都要重新训练一遍吧,应该训练一次就可以一直使用吧. TensorFlow 提供了 Saver 类,可以进行保存和恢 ...

  7. Deep Learning 32: 自己写的keras的一个callbacks函数,解决keras中不能在每个epoch实时显示学习速率learning rate的问题

    一.问题: keras中不能在每个epoch实时显示学习速率learning rate,从而方便调试,实际上也是为了调试解决这个问题:Deep Learning 31: 不同版本的keras,对同样的 ...

  8. 100天搞定机器学习|day40-42 Tensorflow Keras识别猫狗

    100天搞定机器学习|1-38天 100天搞定机器学习|day39 Tensorflow Keras手写数字识别 前文我们用keras的Sequential 模型实现mnist手写数字识别,准确率0. ...

  9. Keras框架下的保存模型和加载模型

    在Keras框架下训练深度学习模型时,一般思路是在训练环境下训练出模型,然后拿训练好的模型(即保存模型相应信息的文件)到生产环境下去部署.在训练过程中我们可能会遇到以下情况: 需要运行很长时间的程序在 ...

随机推荐

  1. 记录一个不同的流媒体网站实现方法,和用Python爬虫爬它的坑

    今天找到一片电影,想把它下载下来. 先开Networks工具分析一下: 初步分析发现,视频加载时会拉取TS格式的文件,推测这是一个m3u8的索引,记录着几百段TS文件,这样方便快进时加载. 但是实际分 ...

  2. Java内存模型和ConcurrentHashMap 1.7源码分析

    简介 ConcurrentHashMap 是 util.concurrent 包的重要成员.本文将结合 Java 内存模型,分析 JDK 源代码,探索 ConcurrentHashMap 高并发的具体 ...

  3. HashCode()与equals()深入理解

    1.hashCode()和equals()方法都是Object类提供的方法, hashCode()返回该对象的哈希码值,该值通常是一个由该对象的内部地址转换而来的int型整数, Object的equa ...

  4. 【每周小项目】使用 puppeteer 插件爬取动态网站

    目录 0. 前言 问题 解决 1. 下载与引包 2. 使用步骤 3. 爬过的几个坑 page.evaluate 的传参问题 元素操作问题 0. 前言 这两天对爬虫开始感兴趣,最开始是源于天涯的一个房价 ...

  5. idea打包或编译错误,错误为c盘idea路径某些文件被占用(非idea文件,项目生成的文件)

    方法列表(2的效果可能更好) 1.将被占用的文件删除之后,重新打包或编译. 2.多编译几次项目. 3.发现真正可能的原因.(貌似被南航企业版360拦截了,导致targe或maven等文件被占用问题) ...

  6. C 旅店

    时间限制 : - MS   空间限制 : - KB  评测说明 : 1s,256m 问题描述 一条笔直的公路旁有N家旅店,从左往右编号1到N,其中第i家旅店的位置坐标为Xi.旅人何老板总在赶路.他白天 ...

  7. H - 覆盖的面积(线段树-线段扫描 + 离散化(板题))

    给定平面上若干矩形,求出被这些矩形覆盖过至少两次的区域的面积. Input 输入数据的第一行是一个正整数T(1<=T<=100),代表测试数据的数量.每个测试数据的第一行是一个正整数N(1 ...

  8. override 重写

    //override:子类继承父类,子类重写父类的方法 public class override { public static void main(String[] args) { horse h ...

  9. Redis 设计与实现笔记 - SDS

    Redis 中的字符串没有使用 C语言中的字符指针(char *),而是使用了自定义的结构 sds. 文件: sds.h sds.c 结构: struct sdshdr { int len; // 填 ...

  10. docker下安装centos,并在其上搭建lnmp环境

    一.安装CentOs容器 1.进入docker下载CentOs,这里我使用的CentOs6.8 docker pull centos:6.8 2.创建容器 sudo docker run --priv ...