现在一直在用TensorFlow训练CNN和LSTM神经网络,但是训练期间遇到了好多坑,现就遇到的各种坑做一下总结

1.问题一;训练CNN的时候出现nan

CNN是我最开始接触的网络,我的研究课题就是利用CNN,LSTM等网络对人体动作做识别。动作数据来源于手机的加速度计,做动作的人在固定位置携带手机并做特定动作,实验人员接收手机的加速度计数值并打上特定的动作标签。

在训练CNN网络时一共遇到两处坑,一是遇到在训练期间遇到nan错误,这个错误很常见。nan的错误多源于你的学习率设置的太大或者batchsize设置的太大,可以10倍10倍的减小学习率直到nan错误不出现。其实要弄明白nan错误怎么出现的才能真正的解决这个错误。

这个错误是因为logits输出太大变成INF,对这个取log就会在求梯度就会变成nan,nan是not a number 的缩写,表示不是一个有理数。所以除了调小学习率这个解决方案,另一个解决方案还可以给loss加正则化项。

2.问题二;在训练CNN的时候loss一直降不下去

我用普通的卷积核对数据做卷积,准确率保持在0.4左右不动了,loss也是到后面基本不降,各种调网络结构终于找到一种深度可分离卷积结构可以使准确率达到0.9左右。深度卷积与我们了解的普通卷积网络的不同点就是它只做单通道卷积,也就是一次卷积后各个通道的卷积结果不相加,各自独立做为一个featuremap。但是这种网络结构的神经元个数增长很快,所以它的训练会比普通卷积慢很多,对于通道数过多的数据并不是很适合,但是这种网络的鲁棒性很好。

其实我这里出现loss不降的原因我后来找到了,是因为我的数据里有错误,数据本身就含有了nan的数据,错误的数据导致网络无法收敛

3.问题三;训练LSTM网络循环出现nan的错误,一到特定的batch就出现nan

一出现这个问题我就意识到肯定是数据出了问题,后来把数据中的nan替换掉就正常了。因为我的数据加载完后是numpy格式的,替换只需要加一句train= np.nan_to_num(train)就可以了

对数据做一个这样的简单处理,普通卷积网也正常了,真是豁然开朗啊,不用深度卷积,只用普通卷积就可以达到0.96的准确率。

4.我后来把CNN和LSTM联合起来对动作识别,发现不对数据处理nan错误,网络仍然可以正常运行并达到0.97的准确率。证明我自己搭的这个网络结构有很好的鲁棒性。

总结:

通过以上遇到问题可以总结出导致nan问题的原因有如下三个:

1.梯度爆炸:观察log,注意每一轮迭代后的loss。loss随着每轮迭代越来越大,最终超过了浮点型表示的范围,就变成了NaN。

2.不当的损失函数:有时候损失层中loss的计算可能导致NaN的出现。比如,给InfogainLoss层(信息熵损失)输入没有归一化的值,使用带有bug的自定义损失层等等。观测训练产生的log时一开始并不能看到异常,loss也在逐步的降低,但突然之间NaN就出现了

3.错误的输入:输入中就含有NaN,每当学习的过程中碰到这个错误的输入,就会变成NaN。观察log的时候也许不能察觉任何异常,loss逐步的降低,但突然间就变成NaN了。

												

神经网络训练时出现nan错误的更多相关文章

  1. caffe 训练时,出现错误:Check failed: error == cudaSuccess (4 vs. 0) unspecified launch failure

    I0415 15:03:37.603461 27311 solver.cpp:42] Solver scaffolding done.I0415 15:03:37.603549 27311 solve ...

  2. Keras 训练时出现 CUDA_ERROR_OUT_OF_MEMORY 错误

    不用惊慌,再试一次.估计当时GPU内存可分配不足,可手动结束所有python程序后释放相关GPU内存,或者重新运行一次终端

  3. 一文读懂神经网络训练中的Batch Size,Epoch,Iteration

    一文读懂神经网络训练中的Batch Size,Epoch,Iteration 作为在各种神经网络训练时都无法避免的几个名词,本文将全面解析他们的含义和关系. 1. Batch Size 释义:批大小, ...

  4. caffe︱深度学习参数调优杂记+caffe训练时的问题+dropout/batch Normalization

    一.深度学习中常用的调节参数 本节为笔者上课笔记(CDA深度学习实战课程第一期) 1.学习率 步长的选择:你走的距离长短,越短当然不会错过,但是耗时间.步长的选择比较麻烦.步长越小,越容易得到局部最优 ...

  5. 深度学习与CV教程(6) | 神经网络训练技巧 (上)

    作者:韩信子@ShowMeAI 教程地址:http://www.showmeai.tech/tutorials/37 本文地址:http://www.showmeai.tech/article-det ...

  6. 神经网络训练中的Tricks之高效BP(反向传播算法)

    神经网络训练中的Tricks之高效BP(反向传播算法) 神经网络训练中的Tricks之高效BP(反向传播算法) zouxy09@qq.com http://blog.csdn.net/zouxy09 ...

  7. tesorflow - create neural network+结果可视化+加速神经网络训练+Optimizer+TensorFlow

    以下仅为了自己方便查看,绝大部分参考来源:莫烦Python,建议去看原博客 一.添加层 def add_layer() 定义 add_layer()函数 在 Tensorflow 里定义一个添加层的函 ...

  8. tensorflow-gpu版本安装及深度神经网络训练与cpu版本对比

    tensorflow1.0和tensorflow2.0的区别主要是1.0用的静态图 一般情况1.0已经足够,但是如果要进行深度神经网络的训练,当然还是tensorflow2.*-gpu比较快啦. 其中 ...

  9. Cs231n课堂内容记录-Lecture 7 神经网络训练2

    Lecture 7  Training Neural Networks 2 课堂笔记参见:https://zhuanlan.zhihu.com/p/21560667?refer=intelligent ...

随机推荐

  1. The Unreasonable Effectiveness of Recurrent Neural Networks (RNN)

    http://karpathy.github.io/2015/05/21/rnn-effectiveness/ There’s something magical about Recurrent Ne ...

  2. Cannot format given Object as a Date

    这个小错挺有意思的,记录一下 导出Excel的时候,同事直接用 format …… 前提:数据库中该字段是 Timestamp ---- 2016-06-20 22:49:02.967 写个测试说明一 ...

  3. java二叉排序树

    二叉排序树又称二叉查找树.它或者是一颗空树,或者是具有如下性质的二叉树: 1.如果左子树不空,那么左子树上的所有节点均小于它的根节点的值: 2.如果右子树不空,那么右子树上的所有节点均大于它的根节点的 ...

  4. Log4Net各参数API

    <?xml version="1.0" encoding="utf-8" ?> <configuration> <configSe ...

  5. c3p0连接不上sql server

    1.首先判断sqlserver是否安装成功 2如果sqlserver正常运行 可能是因为没有打开1433端口,按以下步骤设置tcp/ip enabled 打开之后再重新启动sql server服务,按 ...

  6. python之路 RabbitMQ、SQLAlchemy

    一.RabbitMQ RabbitMQ是一个在AMQP基础上完整的,可复用的企业消息系统.他遵循Mozilla Public License开源协议. MQ全称为Message Queue, 消息队列 ...

  7. 属性检测 In,hasOwnPreperty()和propertyIsEnumerable()

    IN  左侧是属性名:右侧是对象名, 如果 属性是 自有属性 或者继承属性 则返回 TRUE var o={x:1,y:2} "x" in  o    返回 true: hasOw ...

  8. Spark机器学习5·回归模型(pyspark)

    分类模型的预测目标是:类别编号 回归模型的预测目标是:实数变量 回归模型种类 线性模型 最小二乘回归模型 应用L2正则化时--岭回归(ridge regression) 应用L1正则化时--LASSO ...

  9. Shell 实现找出两个目录下的同名文件方法

    # 首先我们来创建一些 2 个目录,里面的目录结构及相关文件如下所示: # 从上面的测试目录可以看到, lol.txt lol2.txt 两个文件是两个目录下的同名文件 # 有实际例子,思路就容易出来 ...

  10. Django QuerySet API

    https://docs.djangoproject.com/en/2.1/ref/models/querysets/