1 前言

时域卷积网络(Temporal Convolutional Network,TCN)属于卷积神经网络(CNN)家族,于2017年被提出,目前已在多项时间序列数据任务中击败循环神经网络(RNN)家族。

TCN 网络结构

图中,xi 表示第 i 个时刻的特征,可以是多维的。

TCN源码见-->GitHub - philipperemy/keras-tcn: Keras Temporal Convolutional Network.,由于源码过于复杂,新手不易上手,笔者参照源码,手撕了个简洁版的TCN,与君共享。

本文以 MNIST 手写数字分类为例,讲解 TCN 模型。关于 MNIST 数据集的说明,见使用TensorFlow实现MNIST数据集分类

笔者工作空间如下:

代码资源见-->时域卷积网络(TCN)案例模型

2 实验

TCN.py

from tensorflow.examples.tutorials.mnist import input_data
from keras.models import Model
from keras.layers import add,Input,Conv1D,Activation,Flatten,Dense #载入数据
def read_data(path):
mnist=input_data.read_data_sets(path,one_hot=True)
train_x,train_y=mnist.train.images.reshape(-1,28,28),mnist.train.labels,
valid_x,valid_y=mnist.validation.images.reshape(-1,28,28),mnist.validation.labels,
test_x,test_y=mnist.test.images.reshape(-1,28,28),mnist.test.labels
return train_x,train_y,valid_x,valid_y,test_x,test_y #残差块
def ResBlock(x,filters,kernel_size,dilation_rate):
r=Conv1D(filters,kernel_size,padding='same',dilation_rate=dilation_rate,activation='relu')(x) #第一卷积
r=Conv1D(filters,kernel_size,padding='same',dilation_rate=dilation_rate)(r) #第二卷积
if x.shape[-1]==filters:
shortcut=x
else:
shortcut=Conv1D(filters,kernel_size,padding='same')(x) #shortcut(捷径)
o=add([r,shortcut])
o=Activation('relu')(o) #激活函数
return o #序列模型
def TCN(train_x,train_y,valid_x,valid_y,test_x,test_y):
inputs=Input(shape=(28,28))
x=ResBlock(inputs,filters=32,kernel_size=3,dilation_rate=1)
x=ResBlock(x,filters=32,kernel_size=3,dilation_rate=2)
x=ResBlock(x,filters=16,kernel_size=3,dilation_rate=4)
x=Flatten()(x)
x=Dense(10,activation='softmax')(x)
model=Model(input=inputs,output=x)
#查看网络结构
model.summary()
#编译模型
model.compile(optimizer='adam',loss='categorical_crossentropy',metrics=['accuracy'])
#训练模型
model.fit(train_x,train_y,batch_size=500,nb_epoch=30,verbose=2,validation_data=(valid_x,valid_y))
#评估模型
pre=model.evaluate(test_x,test_y,batch_size=500,verbose=2)
print('test_loss:',pre[0],'- test_acc:',pre[1]) train_x,train_y,valid_x,valid_y,test_x,test_y=read_data('MNIST_data')
TCN(train_x,train_y,valid_x,valid_y,test_x,test_y)

网络各层输出尺寸:

__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
input_1 (InputLayer) (None, 28, 28) 0
__________________________________________________________________________________________________
conv1d_1 (Conv1D) (None, 28, 32) 2720 input_1[0][0]
__________________________________________________________________________________________________
conv1d_2 (Conv1D) (None, 28, 32) 3104 conv1d_1[0][0]
__________________________________________________________________________________________________
conv1d_3 (Conv1D) (None, 28, 32) 2720 input_1[0][0]
__________________________________________________________________________________________________
add_1 (Add) (None, 28, 32) 0 conv1d_2[0][0]
conv1d_3[0][0]
__________________________________________________________________________________________________
activation_1 (Activation) (None, 28, 32) 0 add_1[0][0]
__________________________________________________________________________________________________
conv1d_4 (Conv1D) (None, 28, 32) 3104 activation_1[0][0]
__________________________________________________________________________________________________
conv1d_5 (Conv1D) (None, 28, 32) 3104 conv1d_4[0][0]
__________________________________________________________________________________________________
add_2 (Add) (None, 28, 32) 0 conv1d_5[0][0]
activation_1[0][0]
__________________________________________________________________________________________________
activation_2 (Activation) (None, 28, 32) 0 add_2[0][0]
__________________________________________________________________________________________________
conv1d_6 (Conv1D) (None, 28, 16) 1552 activation_2[0][0]
__________________________________________________________________________________________________
conv1d_7 (Conv1D) (None, 28, 16) 784 conv1d_6[0][0]
__________________________________________________________________________________________________
conv1d_8 (Conv1D) (None, 28, 16) 1552 activation_2[0][0]
__________________________________________________________________________________________________
add_3 (Add) (None, 28, 16) 0 conv1d_7[0][0]
conv1d_8[0][0]
__________________________________________________________________________________________________
activation_3 (Activation) (None, 28, 16) 0 add_3[0][0]
__________________________________________________________________________________________________
flatten_1 (Flatten) (None, 448) 0 activation_3[0][0]
__________________________________________________________________________________________________
dense_1 (Dense) (None, 10) 4490 flatten_1[0][0]
==================================================================================================
Total params: 23,130
Trainable params: 23,130
Non-trainable params: 0

网络训练结果:

Epoch 28/30
- 6s - loss: 0.0112 - acc: 0.9966 - val_loss: 0.0539 - val_acc: 0.9854
Epoch 29/30
- 6s - loss: 0.0080 - acc: 0.9977 - val_loss: 0.0536 - val_acc: 0.9872
Epoch 30/30
- 6s - loss: 0.0099 - acc: 0.9965 - val_loss: 0.0486 - val_acc: 0.9892
test_loss: 0.055041389787220396 - test_acc: 0.9855000048875808

可以看到,TCN模型的预测精度为 0.9855, 超越了 seq2seq模型案例分析 中 AttSeq2Seq 模型(0.9825)、基于keras的双层LSTM网络和双向LSTM网络 中 DoubleLSTM 模型(0.9789)和 BiLSTM 模型(0.9795)、基于keras的残差网络 中 ResNet 模型(0.9721)。

3 拓展延申

有时候,并不需要最后一层 TCN 输出序列的所有步,而只需要最后一层 TCN 输出序列的第一步或最后一步。这时候,需要借助 lambda 关键字定义 Lambda 层,取代 Flatten 层。如下:

from keras.layers import Lambda
......
x=ResBlock(x,filters=16,kernel_size=3,dilation_rate=4)
x=Lambda(lambda x: x[:,0,:])(x) #此前是:x=Flatten()(x)
x=Dense(10,activation='softmax')(x)
......

lambda 关键字用于定义匿名函数,应用如下:

import numpy as np
f=lambda x: x*x+x+1
x=np.array([1,2,3])
y=f(x)
print(y) #输出:[ 3 7 13]

​ 声明:本文转自基于keras的时域卷积网络(TCN)

基于keras的时域卷积网络(TCN)的更多相关文章

  1. R-FCN:基于区域的全卷积网络来检测物体

    http://blog.csdn.net/shadow_guo/article/details/51767036 原文标题为“R-FCN: Object Detection via Region-ba ...

  2. 时空卷积网络TCN

    1.写在前面 实验表明,RNN 在几乎所有的序列问题上都有良好表现,包括语音/文本识别.机器翻译.手写体识别.序列数据分析(预测)等. 在实际应用中,RNN 在内部设计上存在一个严重的问题:由于网络一 ...

  3. 用keras作CNN卷积网络书本分类(书本、非书本)

    本文介绍如何使用keras作图片分类(2分类与多分类,其实就一个参数的区别...呵呵) 先来看看解决的问题:从一堆图片中分出是不是书本,也就是最终给图片标签上:“书本“.“非书本”,简单吧. 先来看看 ...

  4. 基于孪生卷积网络(Siamese CNN)和短时约束度量联合学习的tracklet association方法

    基于孪生卷积网络(Siamese CNN)和短时约束度量联合学习的tracklet association方法 Siamese CNN Temporally Constrained Metrics T ...

  5. keras搭建密集连接网络/卷积网络/循环网络

    输入模式与网络架构间的对应关系: 向量数据:密集连接网络(Dense层) 图像数据:二维卷积神经网络 声音数据(比如波形):一维卷积神经网络(首选)或循环神经网络 文本数据:一维卷积神经网络(首选)或 ...

  6. 基于 Keras 用 LSTM 网络做时间序列预测

    目录 基于 Keras 用 LSTM 网络做时间序列预测 问题描述 长短记忆网络 LSTM 网络回归 LSTM 网络回归结合窗口法 基于时间步的 LSTM 网络回归 在批量训练之间保持 LSTM 的记 ...

  7. 基于TensorFlow解决手写数字识别的Softmax方法、多层卷积网络方法和前馈神经网络方法

    一.基于TensorFlow的softmax回归模型解决手写字母识别问题 详细步骤如下: 1.加载MNIST数据: input_data.read_data_sets('MNIST_data',one ...

  8. TCN时间卷积网络——解决LSTM的并发问题

    TCN是指时间卷积网络,一种新型的可以用来解决时间序列预测的算法.在这一两年中已有多篇论文提出,但是普遍认为下篇论文是TCN的开端. 论文名称: An Empirical Evaluation of ...

  9. 【GCN】图卷积网络初探——基于图(Graph)的傅里叶变换和卷积

    [GCN]图卷积网络初探——基于图(Graph)的傅里叶变换和卷积 2018年11月29日 11:50:38 夏至夏至520 阅读数 5980更多 分类专栏: # MachineLearning   ...

  10. OverFeat:基于卷积网络的集成识别、定位与检测

    摘要:我们提出了一个使用卷积网络进行分类.定位和检测的集成框架.我们展示了如何在ConvNet中有效地实现多尺度和滑动窗口方法.我们还介绍了一种新的深度学习方法,通过学习预测对象边界来定位.然后通过边 ...

随机推荐

  1. 【Gui-Guider】安装后运行模拟器报 JAVA 错误

    运行模拟器出错 上述错误是因为需要JAVA环境 JAVA 环境下载网址 https://www.oracle.com/java/technologies/javase-jdk16-downloads. ...

  2. [转帖]Oracle如何重启mmon/mmnl进程(AWR自动采集)

    https://www.cnblogs.com/jyzhao/p/10119854.html 学习一下 环境:Oracle 11.2.0.4 RAC现象:sysaux空间满导致无法正常生成快照,清理空 ...

  3. [转帖]Kafka查看topic、consumer group状态命令

    https://www.cnblogs.com/AcAc-t/p/kafka_topic_consumer_group_command.html 最近工作中遇到需要使用kafka的场景,测试消费程序启 ...

  4. [转帖]TIDB-TIDB节点磁盘已满报警

    一.背景 今日突然收到tidb节点的磁盘报警,磁盘容量已经超过了80%,但是tidb是不放数据的,磁盘怎么会满,这里就需要排查了 二.问题排查 解决步骤 1.df -h查看哪里占用磁盘比较多,然后通过 ...

  5. [转帖]harbor镜像仓库清理操作

    https://www.cnblogs.com/FengGeBlog/p/15517706.html 两年前清理过一次harbor镜像,而现在又要面临清镜像的操作了,笔者目前所在的公司镜像是存放在ce ...

  6. [转帖]【压测】通过Jemeter进行压力测试(超详细)

    文章目录 背景 一.前言 二.关于JMeter 三.准备工作 四.创建测试 4.1.创建线程组 4.2.配置元件 4.3.构造HTTP请求 4.4.添加HTTP请求头 4.5.添加断言 4.6.添加察 ...

  7. 多个物理磁盘挂载到同一目录的方法 (lvm 软raid)

    多个物理磁盘挂载到同一目录的方法 (lvm 软raid) 背景 公司里面的一台申威3231的机器 因为这个机器的raid卡没有操作界面. 所以只能够通过命令行方式创建raid 自己这一块比较菜, 想着 ...

  8. K3S +Helm+NFS最小化测试安装部署只需十分钟

    作者:郝建伟 k3s 简介 官方文档:k3s 什么是k3s k3s 是一个轻量级的 Kubernetes 发行版 它针对边缘计算.物联网等场景进行了高度优化. k3s 有以下增强功能: 打包为单个二进 ...

  9. 【解决一个小问题】golang 的 `-race`选项导致 unsafe代码 panic

    作者:张富春(ahfuzhang),转载时请注明作者和引用链接,谢谢! cnblogs博客 zhihu Github 公众号:一本正经的瞎扯 为了提升性能,使用 unsafe 代码来重构了凯撒加密的代 ...

  10. Unity中的string gc优化

    在项目中如果有大量的字符串拼接,比如每秒执行的倒计时,协议中的日志输出,每次拼接会产生大量的gc,尤其是在ILRuntime下执行 gc alloc的次数会更加频繁. zstring 有两个字符串处理 ...