tflearn 在每一个epoch完毕保存模型
关键代码:
tflearn.DNN(net, checkpoint_path='model_resnet_cifar10',
max_checkpoints=10, tensorboard_verbose=0,
clip_gradients=0.)
snapshot_epoch=True, # Snapshot (save & evaluate) model every epoch.
我的demo:
def get_model(width, height, classes=40):
# TODO, modify model
network = input_data(shape=[None, width, height, 3]) # if RGB, 224,224,3
# Residual blocks
# 32 layers: n=5, 56 layers: n=9, 110 layers: n=18
n = 2
net = tflearn.conv_2d(network, 16, 3, regularizer='L2', weight_decay=0.0001)
net = tflearn.residual_block(net, n, 16)
net = tflearn.residual_block(net, 1, 32, downsample=True)
net = tflearn.residual_block(net, n-1, 32)
net = tflearn.residual_block(net, 1, 64, downsample=True)
net = tflearn.residual_block(net, n-1, 64)
net = tflearn.batch_normalization(net)
net = tflearn.activation(net, 'relu')
net = tflearn.global_avg_pool(net)
# Regression
net = tflearn.fully_connected(net, classes, activation='softmax')
#mom = tflearn.Momentum(0.1, lr_decay=0.1, decay_step=32000, staircase=True)
mom = tflearn.Momentum(0.01, lr_decay=0.1, decay_step=2000, staircase=True)
net = tflearn.regression(net, optimizer=mom,
loss='categorical_crossentropy')
# Training
model = tflearn.DNN(net, checkpoint_path='model_resnet_cifar10',
max_checkpoints=10, tensorboard_verbose=0,
clip_gradients=0.)
return model def main():
trainX, trainY = image_preloader("data/train", image_shape=(width, height, 3), mode='folder', categorical_labels=True, normalize=True)
testX, testY = image_preloader("data/test", image_shape=(width, height, 3), mode='folder', categorical_labels=True, normalize=True)
#trainX = trainX.reshape([-1, width, height, 1])
#testX = testX.reshape([-1, width, height, 1])
print("sample data:")
print(trainX[0])
print(trainY[0])
print(testX[-1])
print(testY[-1]) model = get_model(width, height, classes=3755) filename = 'tflearn_resnet/model.tflearn'
# try to load model and resume training
try:
#model.load(filename)
model.load("model_resnet_cifar10-195804")
print("Model loaded OK. Resume training!")
except:
pass early_stopping_cb = EarlyStoppingCallback(val_acc_thresh=0.94)
try:
model.fit(trainX, trainY, validation_set=(testX, testY), n_epoch=500, shuffle=True,
snapshot_epoch=True, # Snapshot (save & evaluate) model every epoch.
show_metric=True, batch_size=1024, callbacks=early_stopping_cb, run_id='cnn_handwrite')
except StopIteration as e:
print("OK, stop iterate!Good!") model.save(filename) del tf.get_collection_ref(tf.GraphKeys.TRAIN_OPS)[:]
filename = 'tflearn_resnet/model-infer.tflearn'
model.save(filename)
tflearn 在每一个epoch完毕保存模型的更多相关文章
- pytorch加载和保存模型
在模型完成训练后,我们需要将训练好的模型保存为一个文件供测试使用,或者因为一些原因我们需要继续之前的状态训练之前保存的模型,那么如何在PyTorch中保存和恢复模型呢? 方法一(推荐): 第一种方法也 ...
- pytorch保存模型等相关参数,利用torch.save(),以及读取保存之后的文件
本文分为两部分,第一部分讲如何保存模型参数,优化器参数等等,第二部分则讲如何读取. 假设网络为model = Net(), optimizer = optim.Adam(model.parameter ...
- Socket编程模型之完毕port模型
转载请注明来源:viewmode=contents">http://blog.csdn.net/caoshiying?viewmode=contents 一.回想重叠IO模型 用完毕例 ...
- ChatGirl 一个基于 TensorFlow Seq2Seq 模型的聊天机器人[中文文档]
ChatGirl 一个基于 TensorFlow Seq2Seq 模型的聊天机器人[中文文档] 简介 简单地说就是该有的都有了,但是总体跑起来效果还不好. 还在开发中,它工作的效果还不好.但是你可以直 ...
- TensorFlow-Bitcoin-Robot:一个基于 TensorFlow LSTM 模型的 Bitcoin 价格预测机器人
简介 TensorFlow-Bitcoin-Robot:一个基于 TensorFlow LSTM 模型的 Bitcoin 价格预测机器人. 文章包括一下几个部分: 1.为什么要尝试做这个项目? 2.为 ...
- sklearn保存模型-【老鱼学sklearn】
训练好了一个Model 以后总需要保存和再次预测, 所以保存和读取我们的sklearn model也是同样重要的一步. 比如,我们根据房源样本数据训练了一下房价模型,当用户输入自己的房子后,我们就需要 ...
- PyTorch保存模型与加载模型+Finetune预训练模型使用
Pytorch 保存模型与加载模型 PyTorch之保存加载模型 参数初始化参 数的初始化其实就是对参数赋值.而我们需要学习的参数其实都是Variable,它其实是对Tensor的封装,同时提供了da ...
- (原)tensorflow保存模型及载入保存的模型
转载请注明出处: http://www.cnblogs.com/darkknightzh/p/7198773.html 参考网址: http://stackoverflow.com/questions ...
- 转sklearn保存模型
训练好了一个Model 以后总需要保存和再次预测, 所以保存和读取我们的sklearn model也是同样重要的一步. 比如,我们根据房源样本数据训练了一下房价模型,当用户输入自己的房子后,我们就需要 ...
随机推荐
- 怎样使用 iOS 7 的 AVSpeechSynthesizer 制作有声书(3)
plist 中的每一页 utteranceSting 我们都创建了一个RWTPage.displayText.因此,每页的文本会一次性地显示出来. 由于 You've constructedeach ...
- ext tree展开时的一些技巧
加入子节点的时候.我们须要展开父节点.并选中刚加入好的节点. 这时候会有一个问题. 我用的ext-js-4.2起码有一种问题. 节点内部会混乱.要么多加一个. 要么层级会发生故障. 随后我发现一个窍门 ...
- OC中动态创建可变数组的问题.有一个数组,数组中有13个元素,先将该数组进行分组,每3个元素为一组,分为若干组,最后用一个数组统一管理这些分组.(要动态创建数组).两种方法
<span style="font-size:24px;">//////第一种方法 // NSMutableArray *arr = [NSMutableArray a ...
- Web前端开发--JS技术大梳理
什么是JS JavaScript是一种直译式脚本语言,是一种动态类型.弱类型.基于原型的语言,内置支持类型.它的解释器被称为JavaScript引擎,为浏览器的一部分,广泛用于客户端的脚本语 ...
- [转]Linux shell中的那些小把戏
我日常使用Linux shell(Bash),但是我经常忘记一些有用的命令或者shell技巧.是的,我能记住一些命令,但是肯定不会只在特定的任务上使用一次,所以我就开始在我的Dropbox账号里用文本 ...
- JAVA实现KNN分类
转载请注明出处:http://blog.csdn.net/xiaojimanman/article/details/51064307 http://www.llwjy.com/blogdetail/f ...
- Nginx 一些常用的URL 重写方法
url重写应该不陌生,不管是SEO URL 伪静态的需要,还是在非常流行的wordpress里,重写无处不在. 1. 在 Apache 的写法 RewriteCond %{HTTP_HOST} n ...
- Redis(六):java里常用的redis客户端(Jedis和Redisson)
Redis的各种语言客户端列表,请参见Redis Client.其中Java客户端在github上start最高的是Jedis和Redisson.Jedis提供了完整Redis命令,而Redisson ...
- grunt使用一步一步讲解
grunt 是一套前端自动化工具,一个基于nodeJs的命令行工具,一般用于:① 压缩文件② 合并文件③ 简单语法检查 对于其他用法,我还不太清楚,我们这里简单介绍下grunt的压缩.合并文件,初学, ...
- PHP下最好用的富文本HTML过滤器:HTMLPurifier使用教程
HTMLPurifier是我目前用过最好的PHP富文本HTML过滤器了,采用了白名单机制,有效杜绝了用户提交表单中的非法HTML标签,从而可以防止XSS攻击! HTMLPurifier项目地址:htt ...