传感器数据集

这个项目使用了 WISDM (Wireless Sensor Data Mining) Lab 实验室公开的 Actitracker 的数据集。 WISDM 公开了两个数据集,一个是在实验室环境采集的;另一个是在真实使用场景中采集的,这里使用的是实验室环境采集的数据。

  • 测试记录:1,098,207 条
  • 测试人数:36 人
  • 采样频率:20 Hz
  • 行为类型:6 种
    • 走路
    • 慢跑
    • 上楼梯
    • 下楼梯
    • 站立
  • 传感器类型:加速度
  • 测试场景:手机放在衣兜里面

数据分析

从 实验室采集数据下载地址 下载数据集压缩包,解压后可以看到下面这些文件:

  • readme.txt
  • WISDM_ar_v1.1_raw_about.txt
  • WISDM_ar_v1.1_trans_about.txt
  • WISDM_ar_v1.1_raw.txt
  • WISDM_ar_v1.1_transformed.arff

我们需要的是包含 RAW 数据的 WISDM_ar_v1.1_raw.txt 文件,其他的是转换后的或者说明文件。文件的每一行数据包含了测试id,行为类型,时间,加速度x,加速度y和加速度z。

这是一个不平衡的数据集,各个行为类型的数据比例不同,官方的描述文件中也提到:

  • Walking: 424,400 (38.6%)
  • Jogging: 342,177 (31.2%)
  • Upstairs: 122,869 (11.2%)
  • Downstairs: 100,427 (9.1%)
  • Sitting: 59,939 (5.5%)
  • Standing: 48,395 (4.4%)

创建、训练并测试模型

模型可以分为数据输入模型,训练模型创建和测试评估模型。

数据输入类型可以分为数据集加载,数据标准化,创建输入数据。

加载数据集

文件的数据保存在 WISDM_ar_v1.1_raw.txt中,可以用pandas里面的函数,read_csv()进行读取。

def read_data(file_path):
column_names = ['user-id', 'activity', 'timestamp', 'x-axis', 'y-axis', 'z-axis']
data = pd.read_csv(file_path, header=None, names=column_names)
return data

数据标准化

对变量的标准差标准化:标准差标准化是将某变量中的观察值减去该变量的平均数,然后除以该变量的标准差。即

x’ik = (xik -mean )/sk

经过标准差标准化后,各变量将有约一半观察值的数值小于0,另一半观察值的数值大于0,变量的平均数为0,标准差为1。经标准化的数据都是没有单位的纯数量。对变量进行的标准差标准化可以消除量纲(单位)影响和变量自身变异的影响。但有人认为经过这种标准化后,原来数值较大的的观察值对分类结果的影响仍然占明显的优势,应该进一步消除大小因子的影响。尽管如此,它还是当前用得最多的数据标准化方法。

# 数据标准化
def feature_normalize(dataset):
mu = np.mean(dataset, axis=0)
sigma = np.std(dataset, axis=0)
return (dataset - mu) / sigma

创建输入数据

创建输入数据,每一组数据包含x,y,z三个轴的90条连续记录,用‘stats.mode’方法获取这90条记录中出现次数最多的行为作为该组行为的标签,

def segment_signal(data, window_size=90):
segments = np.empty((0, window_size, 3))
labels = np.empty((0))
print len(data['timestamp'])
count = 0
for (start, end) in windows(data['timestamp'], window_size):
print count
count += 1
x = data["x-axis"][start:end]
y = data["y-axis"][start:end]
z = data["z-axis"][start:end]
if (len(dataset['timestamp'][start:end]) == window_size):
segments = np.vstack([segments, np.dstack([x, y, z])])
labels = np.append(labels, stats.mode(data["activity"][start:end])[0][0])
return segments, labels

训练模型创建(cnn的结构如下):

  • 输入:1*90大小的向量,3通道.(每天轴为一个通道,90为连续的90条记录)
  • 第一层卷积:1*10大小的卷积核60个。
  • 第一层max-pooling:20*1的核。
  • 第二层卷积:1*6卷积核10个。
  • 第一层全连接:有1000个隐藏的神经元,采用tanh的激活函数
  • Softmax层:归一化。

训练过程采用的是梯度下降算法。

卷积函数和池化函数的定义:

def depthwise_conv2d(x, W):
return tf.nn.depthwise_conv2d(x, W, [1, 1, 1, 1], padding='VALID') # 为输入数据的每个 channel 执行一维卷积,并输出到 ReLU 激活函数
def apply_depthwise_conv(x, kernel_size, num_channels, depth):
weights = weight_variable([1, kernel_size, num_channels, depth])
biases = bias_variable([depth * num_channels])
return tf.nn.relu(tf.add(depthwise_conv2d(x, weights), biases)) # 在卷积层输出进行一维 max pooling
def apply_max_pool(x, kernel_size, stride_size):
return tf.nn.max_pool(x, ksize=[1, 1, kernel_size, 1],
strides=[1, 1, stride_size, 1], padding='VALID')

卷积网络的建立:

# 下面是使用 Tensorflow 创建神经网络的过程。
X = tf.placeholder(tf.float32, shape=[None,input_height,input_width,num_channels])
Y = tf.placeholder(tf.float32, shape=[None,num_labels]) c = apply_depthwise_conv(X,kernel_size,num_channels,depth)
p = apply_max_pool(c,20,2)
c = apply_depthwise_conv(p,6,depth*num_channels,depth//10) shape = c.get_shape().as_list()
c_flat = tf.reshape(c, [-1, shape[1] * shape[2] * shape[3]]) f_weights_l1 = weight_variable([shape[1] * shape[2] * depth * num_channels * (depth//10), num_hidden])
f_biases_l1 = bias_variable([num_hidden])
f = tf.nn.tanh(tf.add(tf.matmul(c_flat, f_weights_l1),f_biases_l1)) out_weights = weight_variable([num_hidden, num_labels])
out_biases = bias_variable([num_labels])
y_ = tf.nn.softmax(tf.matmul(f, out_weights) + out_biases)

损失函数集训练函数

loss = -tf.reduce_sum(Y * tf.log(y_))
optimizer = tf.train.GradientDescentOptimizer(learning_rate = learning_rate).minimize(loss)

测试评估模型

测试评估函数

correct_prediction = tf.equal(tf.argmax(y_,1), tf.argmax(Y,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

我们将所有的数据进行7/3分,7用于对模型进行训练,3用于对模型精度进行测试,其结果如下:

Precision 0.888409996548

Recall 0.884568153909

f1_score 0.880684696544

可以看成在在只迭代了8次的情况之下,等到了0.88的精度,效果非常好。

基于手机传感器数据使用 CNN 识别用户行为的 Tensroflow 实现的更多相关文章

  1. 【AllJoyn专题】基于AllJoyn和Yeelink的传感器数据上传与指令下行的研究

    接触高通物联网框架AllJoyn不太久,但确是被深深地吸引了.在我看来,促进我深入学习的原因有三点:一.AllJoyn开源,对开源的软硬件总会有种莫名的喜爱,虽然或许不会都深入下去:二.顺应潮流,物联 ...

  2. 客户流失?来看看大厂如何基于spark+机器学习构建千万数据规模上的用户留存模型 ⛵

    作者:韩信子@ShowMeAI 大数据技术 ◉ 技能提升系列:https://www.showmeai.tech/tutorials/84 行业名企应用系列:https://www.showmeai. ...

  3. python 手机App数据抓取实战二抖音用户的抓取

    前言 什么?你问我国庆七天假期干了什么?说出来你可能不信,我爬取了cxk坤坤的抖音粉丝数据,我也不知道我为什么这么无聊. 本文主要记录如何使用appium自动化工具实现抖音App模拟滑动,然后分析数据 ...

  4. 基于uFUN开发板的心率计(一)DMA方式获取传感器数据

    前言 从3月8号收到板子,到今天算起来,uFUN到手也有两周的时间了,最近利用下班后的时间,做了个心率计,从单片机程序到上位机开发,到现在为止完成的差不多了,实现很简单,uFUN开发板外加一个Puls ...

  5. OneAPM NI 基于旁路镜像数据的真实用户体验监控

    在这个应用无处不在的时代,一次网络购物,一次网络银行交易,一次网络保险的购买,一次春运车票的购买,一次重要工作邮件的收发中出现的延时,卡顿对企业都可能意味着用户忠诚度下降,真金白银的损失. 因而感知真 ...

  6. 基于LeNet网络的中文验证码识别

    基于LeNet网络的中文验证码识别 由于公司需要进行了中文验证码的图片识别开发,最近一段时间刚忙完上线,好不容易闲下来就继上篇<基于Windows10 x64+visual Studio2013 ...

  7. 基于深度学习的人脸性别识别系统(含UI界面,Python代码)

    摘要:人脸性别识别是人脸识别领域的一个热门方向,本文详细介绍基于深度学习的人脸性别识别系统,在介绍算法原理的同时,给出Python的实现代码以及PyQt的UI界面.在界面中可以选择人脸图片.视频进行检 ...

  8. 【基于WPF+OneNote+Oracle的中文图片识别系统阶段总结】之篇二:基于OneNote难点突破和批量识别

    篇一:WPF常用知识以及本项目设计总结:http://www.cnblogs.com/baiboy/p/wpf.html 篇二:基于OneNote难点突破和批量识别:http://www.cnblog ...

  9. 转:基于开源项目OpenCV的人脸识别Demo版整理(不仅可以识别人脸,还可以识别眼睛鼻子嘴等)【模式识别中的翘楚】

    文章来自于:http://blog.renren.com/share/246648717/8171467499 基于开源项目OpenCV的人脸识别Demo版整理(不仅可以识别人脸,还可以识别眼睛鼻子嘴 ...

随机推荐

  1. 第二部分 职责型模式responsibility

    普通职责无法提供的内容,据此可以定义以下几种模式: 将职责集中到某个类的一个单独实例,单件模式 当一个对象发生改变时,依赖于这个对象的其他对象都能够得到通知,而这个发生改变的对象无须了解自己被其他哪些 ...

  2. C: strcpy & memcpy & scanf/printf format specifier.. escape characters..

    well, strcpy differs from memcpy in that it stops copy at \0 the format specifier is a string.. whic ...

  3. 【报错】java.lang.RuntimeException: Invalid action class configuration that references an unknown class named [xxxAction]

    java.lang.RuntimeException: Invalid action class configuration that references an unknown class name ...

  4. Android:OpenFire 相关API (持续更新)

    基于XMPP协议的聊天服务器.最近会一直更新相关的API. 需要的软件:OpenFire(服务器),Spark(客户端--测试用),Asmack(Jar包) 1.连接服务器的代码 private vo ...

  5. 《C++ Primer》之面向对象编程(二)

    构造函数和复制控制 每个派生类对象由派生类中定义的(非 static)成员加上一个或多个基类子对象构成,当我们构造.复制.赋值和撤销一个派生类对象时,也会构造.复制.赋值和撤销这些基类子对象. 构造函 ...

  6. List-----Array

    1.Definition Arry数组是一种连续储存的List 储存方式:将线性表中的元素一次储存在连续的储存空间中. Computer's logical structure: 逻辑位置上相邻的元素 ...

  7. 更改Xcode的缺省公司名

    更改前: //  testAppDelegate.m //  test // //  Created by gaohf on 11-5-24. //  Copyright 2011 __MyCompa ...

  8. mongo细节

    mongo创建表db.createCollection(name, {capped: <Boolean>, autoIndexId: <Boolean>, size: < ...

  9. postgres-xl 集体搭建

    pgxl 集群搭建 一 预备 1 下载安装解压源码 /opt/ curl -O http://files.postgres-xl.org/postgres-xl95r1beta1.tar.gz tar ...

  10. Javascript调用 ActiveXObject导出excel文档。

    function makeDataBook(){ var xls = new ActiveXObject ("Excel.Application"); xls.visible = ...