用CapsNets做电能质量扰动分类(2019-08-05)
当下最热神经网络为CNN,2017年10月,深度学习之父Hinton发表《胶囊间的动态路由》(Capsule Networks),最近谷歌正式开源了Hinton胶囊理论代码,提出的胶囊神经网络。本文不涉及原理,只是站在巨人的肩膀人,尝试把胶囊网络应用与分类问题。
原理和代码的参考文献是:https://blog.csdn.net/weixin_40920290/article/details/82951826
其中,本文采用的数据集和以2019年3月CNN做电能质量分类的一样,可以去那个博文中下载数据集。这里只展示代码。需要提醒的是,Capsule Networks的运行速度会比较慢,耐心等待.
如果由于格式问题无法运行,可以把邮箱私戳发给我,我把Capsule.py发给你
1.代码
from __future__ import print_function import numpy as np from keras import layers, models, optimizers from keras import backend as K from keras.utils import to_categorical import matplotlib.pyplot as plt from utils import combine_images from PIL import Image from capsulelayers import CapsuleLayer, PrimaryCap, Length, Mask import keras from pandas import read_csv K.set_image_data_format('channels_last') def CapsNet(input_shape, n_class, routings): """ A Capsule Network on MNIST. :param input_shape: data shape, 3d, [width, height, channels] :param n_class: number of classes :param routings: number of routing iterations :return: Two Keras Models, the first one used for training, and the second one for evaluation. `eval_model` can also be used for training. """ x = layers.Input(shape=input_shape) # Layer 1: Just a conventional Conv2D layer conv1 = layers.Conv2D(filters=256, kernel_size=9, strides=1, padding='valid', activation='relu', name='conv1')(x) # Layer 2: Conv2D layer with `squash` activation, then reshape to [None, num_capsule, dim_capsule] primarycaps = PrimaryCap(conv1, dim_capsule=8, n_channels=32, kernel_size=9, strides=2, padding='valid') # Layer 3: Capsule layer. Routing algorithm works here. digitcaps = CapsuleLayer(num_capsule=n_class, dim_capsule=16, routings=routings, name='digitcaps')(primarycaps) # Layer 4: This is an auxiliary layer to replace each capsule with its length. Just to match the true label's shape. # If using tensorflow, this will not be necessary. :) out_caps = Length(name='capsnet')(digitcaps) # Decoder network. y = layers.Input(shape=(n_class,)) masked_by_y = Mask()([digitcaps, y]) # The true label is used to mask the output of capsule layer. For training masked = Mask()(digitcaps) # Mask using the capsule with maximal length. For prediction # Shared Decoder model in training and prediction decoder = models.Sequential(name='decoder') decoder.add(layers.Dense(512, activation='relu', input_dim=16*n_class)) decoder.add(layers.Dense(1024, activation='relu')) decoder.add(layers.Dense(np.prod(input_shape), activation='sigmoid')) decoder.add(layers.Reshape(target_shape=input_shape, name='out_recon')) # Models for training and evaluation (prediction) train_model = models.Model([x, y], [out_caps, decoder(masked_by_y)]) eval_model = models.Model(x, [out_caps, decoder(masked)]) # manipulate model noise = layers.Input(shape=(n_class, 16)) noised_digitcaps = layers.Add()([digitcaps, noise]) masked_noised_y = Mask()([noised_digitcaps, y]) manipulate_model = models.Model([x, y, noise], decoder(masked_noised_y)) return train_model, eval_model, manipulate_model def margin_loss(y_true, y_pred): """ Margin loss for Eq.(4). When y_true[i, :] contains not just one `1`, this loss should work too. Not test it. :param y_true: [None, n_classes] :param y_pred: [None, num_capsule] :return: a scalar loss value. """ L = y_true * K.square(K.maximum(0., 0.9 - y_pred)) + \ 0.5 * (1 - y_true) * K.square(K.maximum(0., y_pred - 0.1)) return K.mean(K.sum(L, 1)) def train(model, data, args): """ Training a CapsuleNet :param model: the CapsuleNet model :param data: a tuple containing training and testing data, like `((x_train, y_train), (x_test, y_test))` :param args: arguments :return: The trained model """ # unpacking the data (x_train, y_train), (x_test, y_test) = data # callbacks log = callbacks.CSVLogger(args.save_dir + '/log.csv') tb = callbacks.TensorBoard(log_dir=args.save_dir + '/tensorboard-logs', batch_size=args.batch_size, histogram_freq=int(args.debug)) checkpoint = callbacks.ModelCheckpoint(args.save_dir + '/weights-{epoch:02d}.h5', monitor='val_capsnet_acc', save_best_only=True, save_weights_only=True, verbose=1) lr_decay = callbacks.LearningRateScheduler(schedule=lambda epoch: args.lr * (args.lr_decay ** epoch)) # compile the model model.compile(optimizer=optimizers.Adam(lr=args.lr), loss=[margin_loss, 'mse'], loss_weights=[1., args.lam_recon], metrics={'capsnet': 'accuracy'}) """ # Training without data augmentation: model.fit([x_train, y_train], [y_train, x_train], batch_size=args.batch_size, epochs=args.epochs, validation_data=[[x_test, y_test], [y_test, x_test]], callbacks=[log, tb, checkpoint, lr_decay]) """ # Begin: Training with data augmentation ---------------------------------------------------------------------# def train_generator(x, y, batch_size, shift_fraction=0.): train_datagen = ImageDataGenerator(width_shift_range=shift_fraction, height_shift_range=shift_fraction) # shift up to 2 pixel for MNIST generator = train_datagen.flow(x, y, batch_size=batch_size) while 1: x_batch, y_batch = generator.next() yield ([x_batch, y_batch], [y_batch, x_batch]) # Training with data augmentation. If shift_fraction=0., also no augmentation. model.fit_generator(generator=train_generator(x_train, y_train, args.batch_size, args.shift_fraction), steps_per_epoch=int(y_train.shape[0] / args.batch_size), epochs=args.epochs, validation_data=[[x_test, y_test], [y_test, x_test]], callbacks=[log, tb, checkpoint, lr_decay]) # End: Training with data augmentation -----------------------------------------------------------------------# model.save_weights(args.save_dir + '/trained_model.h5') print('Trained model saved to \'%s/trained_model.h5\'' % args.save_dir) from utils import plot_log plot_log(args.save_dir + '/log.csv', show=True) return model def test(model, data, args): x_test, y_test = data y_pred, x_recon = model.predict(x_test, batch_size=100) print('-'*30 + 'Begin: test' + '-'*30) print('Test acc:', np.sum(np.argmax(y_pred, 1) == np.argmax(y_test, 1))/y_test.shape[0]) img = combine_images(np.concatenate([x_test[:50],x_recon[:50]])) image = img * 255 Image.fromarray(image.astype(np.uint8)).save(args.save_dir + "/real_and_recon.png") print() print('Reconstructed images are saved to %s/real_and_recon.png' % args.save_dir) print('-' * 30 + 'End: test' + '-' * 30) plt.imshow(plt.imread(args.save_dir + "/real_and_recon.png")) plt.show() def manipulate_latent(model, data, args): print('-'*30 + 'Begin: manipulate' + '-'*30) x_test, y_test = data index = np.argmax(y_test, 1) == args.digit number = np.random.randint(low=0, high=sum(index) - 1) x, y = x_test[index][number], y_test[index][number] x, y = np.expand_dims(x, 0), np.expand_dims(y, 0) noise = np.zeros([1, 10, 16]) x_recons = [] for dim in range(16): for r in [-0.25, -0.2, -0.15, -0.1, -0.05, 0, 0.05, 0.1, 0.15, 0.2, 0.25]: tmp = np.copy(noise) tmp[:,:,dim] = r x_recon = model.predict([x, y, tmp]) x_recons.append(x_recon) x_recons = np.concatenate(x_recons) img = combine_images(x_recons, height=16) image = img*255 Image.fromarray(image.astype(np.uint8)).save(args.save_dir + '/manipulate-%d.png' % args.digit) print('manipulated result saved to %s/manipulate-%d.png' % (args.save_dir, args.digit)) print('-' * 30 + 'End: manipulate' + '-' * 30) def load_mnist(): # the data, shuffled and split between train and test sets dataset = read_csv('ZerosOnePowerQuality.csv') values = dataset.values XY= values num_classes = 8 Y = XY[:,784] n_train_hours1 =9000 x_train=XY[:n_train_hours1,0:784] trainY =Y[:n_train_hours1] x_test =XY[n_train_hours1:, 0:784] testY =Y[n_train_hours1:] x_train = x_train.reshape(-1,28,28,1) x_test = x_test.reshape(-1,28,28,1) y_train = keras.utils.to_categorical(trainY, num_classes) y_test = keras.utils.to_categorical(testY, num_classes) return (x_train, y_train), (x_test, y_test) if __name__ == "__main__": import os import argparse from keras.preprocessing.image import ImageDataGenerator from keras import callbacks # setting the hyper parameters parser = argparse.ArgumentParser(description="Capsule Network on MNIST.") parser.add_argument('--epochs', default=50, type=int) parser.add_argument('--batch_size', default=100, type=int) parser.add_argument('--lr', default=0.001, type=float, help="Initial learning rate") parser.add_argument('--lr_decay', default=0.9, type=float, help="The value multiplied by lr at each epoch. Set a larger value for larger epochs") parser.add_argument('--lam_recon', default=0.392, type=float, help="The coefficient for the loss of decoder") parser.add_argument('-r', '--routings', default=3, type=int, help="Number of iterations used in routing algorithm. should > 0") parser.add_argument('--shift_fraction', default=0.1, type=float, help="Fraction of pixels to shift at most in each direction.") parser.add_argument('--debug', action='store_true', help="Save weights by TensorBoard") parser.add_argument('--save_dir', default='./result') parser.add_argument('-t', '--testing', action='store_true', help="Test the trained model on testing dataset") parser.add_argument('--digit', default=5, type=int, help="Digit to manipulate") parser.add_argument('-w', '--weights', default=None, help="The path of the saved weights. Should be specified when testing") args = parser.parse_args() print(args) if not os.path.exists(args.save_dir): os.makedirs(args.save_dir) # load data (x_train, y_train), (x_test, y_test) = load_mnist() # define model model, eval_model, manipulate_model = CapsNet(input_shape=x_train.shape[1:], n_class=len(np.unique(np.argmax(y_train, 1))), routings=args.routings) model.summary() # train or test if args.weights is not None: # init the model weights with provided one model.load_weights(args.weights) if not args.testing: train(model=model, data=((x_train, y_train), (x_test, y_test)), args=args) else: # as long as weights are given, will run testing if args.weights is None: print('No weights are provided. Will test using random initialized weights.') manipulate_latent(manipulate_model, (x_test, y_test), args) test(model=eval_model, data=(x_test, y_test), args=args) 2.网络结构 Layer (type) Output Shape Param # Connected to ================================================================================================== input_1 (InputLayer) (None, 28, 28, 1) 0 __________________________________________________________________________________________________ conv1 (Conv2D) (None, 20, 20, 256) 20992 input_1[0][0] __________________________________________________________________________________________________ primarycap_conv2d (Conv2D) (None, 6, 6, 256) 5308672 conv1[0][0] __________________________________________________________________________________________________ primarycap_reshape (Reshape) (None, 1152, 8) 0 primarycap_conv2d[0][0] __________________________________________________________________________________________________ primarycap_squash (Lambda) (None, 1152, 8) 0 primarycap_reshape[0][0] __________________________________________________________________________________________________ digitcaps (CapsuleLayer) (None, 8, 16) 1179648 primarycap_squash[0][0] __________________________________________________________________________________________________ input_2 (InputLayer) (None, 8) 0 __________________________________________________________________________________________________ mask_1 (Mask) (None, 128) 0 digitcaps[0][0] input_2[0][0] __________________________________________________________________________________________________ capsnet (Length) (None, 8) 0 digitcaps[0][0] __________________________________________________________________________________________________ decoder (Sequential) (None, 28, 28, 1) 1394960 mask_1[0][0] ================================================================================================== Total params: 7,904,272 Trainable params: 7,904,272 Non-trainable params: 0 ____________________________
此代码实在keras官方教程下修改而成:
https://blog.csdn.net/wyx100/article/details/80724501
用CapsNets做电能质量扰动分类(2019-08-05)的更多相关文章
- 单向LSTM笔记, LSTM做minist数据集分类
单向LSTM笔记, LSTM做minist数据集分类 先介绍下torch.nn.LSTM()这个API 1.input_size: 每一个时步(time_step)输入到lstm单元的维度.(实际输入 ...
- 做一个logitic分类之鸢尾花数据集的分类
做一个logitic分类之鸢尾花数据集的分类 Iris 鸢尾花数据集是一个经典数据集,在统计学习和机器学习领域都经常被用作示例.数据集内包含 3 类共 150 条记录,每类各 50 个数据,每条记录都 ...
- MyBatis 配置/注解 SQL CRUD 经典解决方案(2019.08.15持续更新)
本文旨在记录使用各位大神的经典解决方案. 2019.08.14 更新 Mybatis saveOrUpdate SelectKey非主键的使用 MyBatis实现SaveOrUpdate mybati ...
- http://tedhacker.top/2016/08/05/Spring%E7%BA%BF%E7%A8%8B%E6%B1%A0%E4%BD%BF%E7%94%A8%E6%96%B9%E6%B3%95/
http://tedhacker.top/2016/08/05/Spring%E7%BA%BF%E7%A8%8B%E6%B1%A0%E4%BD%BF%E7%94%A8%E6%96%B9%E6%B3%9 ...
- 新手C#类、对象、字段、方法的学习2018.08.05
类:具有相似属性和方法的对象的集合,如“人”是个类. 对象(实例):对象是具体的看得见摸得着的,如“张三”是“人”这个类的对象.(new Person()开辟了堆空间中,=开辟了栈空间,变量P存放在该 ...
- 【2019年05月20日】A股滚动市盈率PE历史新低排名
2010年01月01日 到 2019年05月20日 之间,滚动市盈率历史新低排名. 上市三年以上的公司, 2019年05月20日市盈率在300以下的公司. 1 - 阳光照明(SH600261) - 历 ...
- 2019.07.05 纪中_B
今日膜拜:czj大佬orz%%% 2019.07.05[NOIP提高组]模拟 B 组 今天做题的时候大概能判断出题人的考点,可是就是没学过...特别痛苦 T0:栈的定义,模拟就好了T1:感觉像是找规律 ...
- http://www.blogjava.net/xylz/archive/2013/08/05/402405.html
http://www.blogjava.net/xylz/archive/2013/08/05/402405.html
- 新手C#s.Split(),s.Substring(,)以及读取txt文件中的字符串的学习2018.08.05
s.split()用于字符串分割,具有多种重载方法,可以通过指定字符或字符串分割原字符串成为字符串数组. //s.Split()用于分割字符串为字符串数组,StringSplitOptions.Rem ...
随机推荐
- window open() 方法
open() 方法用于打开一个新的浏览器窗口或查找一个已命名的窗口. 语法 window.open(URL,name,specs,replace) 参数 说明 URL 可选.打开指定的页面的URL.如 ...
- bug提交遵循的规则
在提交缺陷时,需要遵循以下5个原则: 准确性:缺陷每个组成部分描述准确,不会产生误解,减少“异常”“正常”等模糊词的使用 完整性:复现该缺陷完整的步骤.截图.日志 一致性:按照一致的格式书写全部缺陷信 ...
- 2019牛客多校第五场 generator 1——广义斐波那契循环节&&矩阵快速幂
理论部分 二次剩余 在数论中,整数 $X$ 对整数 $p$ 的二次剩余是指 $X^2$ 除以 $p$ 的余数. 当存在某个 $X$,使得式子 $X^2 \equiv d(mod \ p)$ 成立时,称 ...
- 005_STM32程序移植之_RC522读卡模块
1. 测试环境:STM32C8T6 2. 测试模块:RC522读卡模块 3. 测试接口: RC522读卡模块: VCC------------------3.3V GND--------------- ...
- 五十七.分布式ELK平台、ES安装 、 扩展插件 、Kibana安装
1. ES集群安装 准备1台虚拟机 部署elasticsearch第一个节点 访问9200端口查看是否安装成功 1ELK是日志分析平台,不是一款软件,而是一整套解决方案,是三个软件产品的首字母缩写 ...
- 如何更改电脑ip
首先打开控制面板==>点击网络和internet==>点击网络和共享中心==>点击更改适配器设置==>右键无线连接或宽带连接(视情况而定)==>属性==>双击ipv ...
- MySQL查询top N记录
下面以查询每门课程分数最高的学生以及成绩为例,演示如何查询 top N记录.下图是测试数据,表结构和相关 insert 脚本见<常用SQL之日期格式化和查询重复数据>. 使用自连接[推荐] ...
- AcWing:241. 楼兰图腾(树状数组逆序对)
在完成了分配任务之后,西部314来到了楼兰古城的西部. 相传很久以前这片土地上(比楼兰古城还早)生活着两个部落,一个部落崇拜尖刀(‘V’),一个部落崇拜铁锹(‘∧’),他们分别用V和∧的形状来代表各自 ...
- Js 之复制到剪贴板 clipboard.js
一.下载 https://github.com/zenorocha/clipboard.js/archive/master.zip 二.Demo示例 <!DOCTYPE html> < ...
- Java安全(加密、摘要、签名、证书、SSL、HTTPS)
对于一般的开发人员来说,很少需要对安全领域内的基础技术进行深入的研究,但是鉴于日常系统开发中遇到的各种安全相关的问题,熟悉和了解这些安全技术的基本原理和使用场景还是非常必要的.本文将对非对称加密.数字 ...