这篇文章是小萌新对西储大学轴承故障进行分析,固定特征为故障直径为0.007,电机转速为1797,12k驱动端故障数据(Drive_End)即DE-time。故障类型y值:滚动体故障,内圈故障,3时,6时,12时外圈故障。

  由于CWRU包无法在python3中直接用,因此首先改写了cwru的代码,直接进行数据处理并划分为训练集和测试集。接着对这些数据进行HHT转换,求出imfs,最后再构建LSTM模型进行分析。

  1. import PyEMD
  2. from PyEMD import *
  3. import scipy
  4. from scipy.io import loadmat
  5. import matplotlib.pyplot as plt
  6. from matplotlib import *
  7. import os
  8. import errno
  9. import urllib.request as urllib
  10. import numpy as np
  11.  
  12. from scipy.io import loadmat
  13.  
  14. import random
  15.  
  16. import pandas as pd
  17.  
  18. from keras.callbacks import ModelCheckpoint
  19.  
  20. from keras.models import Model, load_model, Sequential
  21.  
  22. from keras.layers import Dense, Activation, Dropout, Input, Masking, TimeDistributed, LSTM, Conv1D, Flatten
  23.  
  24. from keras.layers import GRU, Bidirectional, BatchNormalization, Reshape
  25.  
  26. from keras.optimizers import Adam
  27.  
  28. from keras.utils import to_categorical
  29.  
  30. from keras.utils import plot_model
  31.  
  32. '''选取训练集和测试集数据
  33. #选取故障直径为0.007,电机转速为1797,12k驱动端故障数据(Drive_End),
  34. #y值分别为滚动体故障,内圈故障,3时,6时,12时外圈故障'''
  35. ## ===================================================选取数据===============================================================
  36. #
  37. class CWRU:
  38.  
  39. def __init__(self, path1,length):
  40. file_list = []
  41. for root,dirs,files in os.walk(path1):
  42. for file in files:
  43. if '12k_Drive_End' in file and '007' in file and '_0_' in file:
  44. file_list.append(file)
  45.  
  46. self.length = length
  47. self._load_and_slice_data(path1, file_list)
  48.  
  49. # shuffle training and test arrays
  50.  
  51. def _load_and_slice_data(self, rdir, infos):
  52. self.X_train = np.zeros((0, self.length))
  53. self.X_test = np.zeros((0, self.length))
  54. self.y_train = []
  55. self.y_test = []
  56. for idx, info in enumerate(infos):
  57. # # directory of this file
  58. fdir = os.path.join(rdir, info)
  59. mat_dict = loadmat(fdir) #载入数据
  60. fliter_i = filter(lambda x: 'DE_time' in x, mat_dict.keys()) #提取数据中的de-time部分
  61. fliter_list = [item for item in fliter_i]
  62. key = fliter_list[0] #这两步是取key值
  63. # key = filter(lambda x: 'DE_time' in x, mat_dict.keys())[0]
  64. time_series1 = mat_dict[key][:, 0] #将DE-time的时间序列取出来
  65. time_series = time_series1[:120001]
  66. idx_last = -(time_series.shape[0] % self.length) #算出信号长度整数倍外还有那些数
  67. clips = time_series[:idx_last].reshape(-1, self.length) # 将提取的时间序列转换成二维,每一个数据的长度为设置的长度
  68.  
  69. n = clips.shape[0] #行数,也就是代表数据量的大小
  70. n_split =int((3 * n / 4)) #设置训练集和测试集的比例
  71. self.X_train = np.vstack((self.X_train, clips[:n_split])) #取训练集
  72. self.X_test = np.vstack((self.X_test, clips[n_split:])) #取测试集
  73. self.y_train += [idx] * n_split #给故障类型设立标签
  74. self.y_test += [idx] * (clips.shape[0] - n_split) #给测试的故障类型设立标签
  75.  
  76. path1 = r"E:\work\CWRU_analysis\CaseWesternReserveUniversityData-master"
  77.  
  78. data = CWRU(path1, 400)
  79.  
  80. X_train,y_train, X_test,y_test = [],[],[],[]
  81.  
  82. X_test.extend(data.X_test)
  83.  
  84. y_test.extend(data.y_test)
  85.  
  86. X_train.extend(data.X_train)
  87.  
  88. y_train.extend(data.y_train)
  89.  
  90. ''' ===============================================将data进行HHT,求出imf,这会转成三维的数据============================================================================
  91. #对每个 数据去求imf,并作为输入'''
  92. def data_to_imf(Data,t):
  93. imf = []
  94. for data in Data:
  95. emd = EMD()
  96. imf_ = emd.emd(data,t)[:5]
  97. imf.append(imf_)
  98. return np.array(imf).reshape(-1,5,400)
  99.  
  100. t = np.linspace(0, 1, 12000)[:400]
  101.  
  102. a = data_to_imf(X_train,t)
  103.  
  104. X_train_data = np.transpose(a,(0,2,1))
  105.  
  106. y_train_data = to_categorical(y_train) #将数据转换成类别矩阵
  107.  
  108. b= data_to_imf(X_test,t)
  109.  
  110. X_test_data = np.transpose(b,(0,2,1))
  111.  
  112. y_test_data = to_categorical(y_test)
  113.  
  114. '''
  115. # =======================================构建LSTM模型并实验========================================================================================
  116. # '''
  117.  
  118. def create_model():
  119. model = Sequential()
  120. #输入数据的shape为(n_samples, timestamps, features)
  121. #隐藏层设置为20, input_shape元组第二个参数1意指features为1
  122. model.add(LSTM(units=20,input_shape=(X_train_data.shape[1], X_train_data.shape[2])))
  123. # model.add(Dropout(0.2))
  124. #后接全连接层,直接输出单个值,故units为10
  125.  
  126. model.add(Dense(units=5))
  127. model.add(Activation('softmax'))#选用非线性激活函数,用于分类
  128. model.compile(loss='categorical_crossentropy',optimizer=Adam(lr=0.001), metrics=['accuracy'])#损失函数为平均均方误差,优化器为Adam,学习率为0.001
  129. return model
  130.  
  131. model = create_model()
  132. history =model.fit(X_train_data, y_train_data, epochs=1000, batch_size=225)
  133.  
  134. #求损失函数
  135. loss, acc = model.evaluate(X_test_data, y_test_data)
  136.  
  137. #保存模型
  138. model_save_path = "model_file_path.h5"
  139. model.save(model_save_path)
  140. model.summary()
  141.  
  142. #输出准确率
  143. print("Dev set accuracy = ", acc)

  本文有参考众多故障分析文章,但是忘了保存,没有链接了。。。。。。。。

基于LSTM对西储大学轴承故障进行分析的更多相关文章

  1. 基于SCADA数据驱动的风电机组部件故障预警

    吴亚联 1 , 梁坤鑫 1 , 苏永新 1* , 詹 俊 2(1.湘潭大学 信息工程学院, 湖南 湘潭 411105: 2.湖南优利泰克自动化系统有限公司, 湖南 长沙 410205) 摘 要: 为提 ...

  2. 【爆料】-《西悉尼大学毕业证书》UWS一模一样原件

    ☞西悉尼大学毕业证书[微/Q:865121257◆WeChat:CC6669834]UC毕业证书/联系人Alice[查看点击百度快照查看][留信网学历认证&博士&硕士&海归&a ...

  3. tensorflow实现基于LSTM的文本分类方法

    tensorflow实现基于LSTM的文本分类方法 作者:u010223750 引言 学习一段时间的tensor flow之后,想找个项目试试手,然后想起了之前在看Theano教程中的一个文本分类的实 ...

  4. 在TensorFlow中基于lstm构建分词系统笔记

    在TensorFlow中基于lstm构建分词系统笔记(一) https://www.jianshu.com/p/ccb805b9f014 前言 我打算基于lstm构建一个分词系统,通过这个例子来学习下 ...

  5. 基于LSTM + keras 的诗歌生成器

        最近在github 上发现了一个好玩的项目,一个基于LSTM + keras 实现的诗歌生成器,地址是:https://github.com/youyuge34/Poems_generator ...

  6. 深度学习|基于LSTM网络的黄金期货价格预测--转载

    深度学习|基于LSTM网络的黄金期货价格预测 前些天看到一位大佬的深度学习的推文,内容很适用于实战,争得原作者转载同意后,转发给大家.之后会介绍LSTM的理论知识. 我把code先放在我github上 ...

  7. 基于Spark和SparkSQL的NetFlow流量的初步分析——scala语言

    基于Spark和SparkSQL的NetFlow流量的初步分析--scala语言 标签: NetFlow Spark SparkSQL 本文主要是介绍如何使用Spark做一些简单的NetFlow数据的 ...

  8. 基于NetMQ的TLS框架NetMQ.Security的实现分析

    基于NetMQ的TLS框架NetMQ.Security的实现分析 前言 介绍 交互过程 支持的协议 TLS协议 支持的算法 实现 握手 第一次握手 Client Hello 第二次握手 Server ...

  9. 基于UML的中职班主任工作管理系统的分析与设计--文献随笔(二)

    一.基本信息 标题:基于UML的中职班主任工作管理系统的分析与设计 时间:2016 出版源:遵义航天工业学校 关键字:中职学校; 班主任工作管理; UML建模 二.研究背景 问题定义:班主任是一项特殊 ...

随机推荐

  1. Django 补充知识

    目录 Django基于配置文件的编程思想 初步实现 大佬实现 跨站请求伪造csrf 什么是csrf? 前端如何解决 ajax解决 csrf相关的装饰器 FBV方式装饰器 CVB方式装饰器 Django ...

  2. kafka数据分区的四种策略

    kafka的数据的分区 探究的是kafka的数据生产出来之后究竟落到了哪一个分区里面去了 第一种分区策略:给定了分区号,直接将数据发送到指定的分区里面去 第二种分区策略:没有给定分区号,给定数据的ke ...

  3. LUOGU P4777 【模板】扩展中国剩余定理(EXCRT)

    传送门 解题思路 扩展 $crt​$,就是中国剩余定理在模数不互质的情况下,首先对于方程 ​     $\begin{cases} x\equiv a_1\mod m_1\\x\equiv a_2\m ...

  4. 莫烦pytorch学习笔记(二)——variable

    .简介 torch.autograd.Variable是Autograd的核心类,它封装了Tensor,并整合了反向传播的相关实现 Variable和tensor的区别和联系 Variable是篮子, ...

  5. pixi.js 学习

     事件(event):PIXI库在精灵和舞台上提供了事件,用于交互. stage.click = function(data){ var event = data.originalEvent } sp ...

  6. VUE下echarts宽度响应式

    window.addEventListener("resize", () => { myChart2.resize();});

  7. 深入理解JVM之类加载

    ---title: [学习]深入理解JVM之类加载.mddate: 2019-10-20 22:20:06tags: JVM 类加载--- Java类的加载,连接,初始化都是在程序运行期间执行的 ## ...

  8. Maven入门指南

    Maven入门指南 本指南旨在第一次为使用Maven的人员提供参考,但也打算作为一本包含公共用例的独立参考和解决方案的工具书.对于新用户,建议您按顺序浏览该材料.对于更熟悉Maven的用户,本指南致力 ...

  9. java基础之完数判断

    完数: 完全数(Perfect number),又称完美数或完备数,是一些特殊的自然数.它所有的真因子(即除了自身以外的约数)的和(即因子函数),恰好等于它本身.如果一个数恰好等于它的因子之和,则称该 ...

  10. 我学习python没有记住的东西

    格式化 # 格式化 a=123 b='ww' print("%d,%s,%%"%(a,b)) # %d,%s,%f,%c,%f 格式化代码:print('{}{}'.format( ...