吴裕雄--天生自然 PYTHON数据分析:基于Keras的CNN分析太空深处寻找系外行星数据
#We import libraries for linear algebra, graphs, and evaluation of results
import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import LinearRegression
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import roc_curve, roc_auc_score
from scipy.ndimage.filters import uniform_filter1d
#Keras is a high level neural networks library, based on either tensorflow or theano
from keras.models import Sequential, Model
from keras.layers import Conv1D, MaxPool1D, Dense, Dropout, Flatten, BatchNormalization, Input, concatenate, Activation
from keras.optimizers import Adam
INPUT_LIB = 'F:\\kaggleDataSet\\kepler-labelled\\'
raw_data = np.loadtxt(INPUT_LIB + 'exoTrain.csv', skiprows=1, delimiter=',')
x_train = raw_data[:, 1:]
y_train = raw_data[:, 0, np.newaxis] - 1.
raw_data = np.loadtxt(INPUT_LIB + 'exoTest.csv', skiprows=1, delimiter=',')
x_test = raw_data[:, 1:]
y_test = raw_data[:, 0, np.newaxis] - 1.
del raw_data
x_train = ((x_train - np.mean(x_train, axis=1).reshape(-1,1))/ np.std(x_train, axis=1).reshape(-1,1))
x_test = ((x_test - np.mean(x_test, axis=1).reshape(-1,1)) / np.std(x_test, axis=1).reshape(-1,1))
x_train = np.stack([x_train, uniform_filter1d(x_train, axis=1, size=200)], axis=2)
x_test = np.stack([x_test, uniform_filter1d(x_test, axis=1, size=200)], axis=2)
model = Sequential()
model.add(Conv1D(filters=8, kernel_size=11, activation='relu', input_shape=x_train.shape[1:]))
model.add(MaxPool1D(strides=4))
model.add(BatchNormalization())
model.add(Conv1D(filters=16, kernel_size=11, activation='relu'))
model.add(MaxPool1D(strides=4))
model.add(BatchNormalization())
model.add(Conv1D(filters=32, kernel_size=11, activation='relu'))
model.add(MaxPool1D(strides=4))
model.add(BatchNormalization())
model.add(Conv1D(filters=64, kernel_size=11, activation='relu'))
model.add(MaxPool1D(strides=4))
model.add(Flatten())
model.add(Dropout(0.5))
model.add(Dense(64, activation='relu'))
model.add(Dropout(0.25))
model.add(Dense(64, activation='relu'))
model.add(Dense(1, activation='sigmoid'))
def batch_generator(x_train, y_train, batch_size=32):
"""
Gives equal number of positive and negative samples, and rotates them randomly in time
"""
half_batch = batch_size // 2
x_batch = np.empty((batch_size, x_train.shape[1], x_train.shape[2]), dtype='float32')
y_batch = np.empty((batch_size, y_train.shape[1]), dtype='float32') yes_idx = np.where(y_train[:,0] == 1.)[0]
non_idx = np.where(y_train[:,0] == 0.)[0] while True:
np.random.shuffle(yes_idx)
np.random.shuffle(non_idx) x_batch[:half_batch] = x_train[yes_idx[:half_batch]]
x_batch[half_batch:] = x_train[non_idx[half_batch:batch_size]]
y_batch[:half_batch] = y_train[yes_idx[:half_batch]]
y_batch[half_batch:] = y_train[non_idx[half_batch:batch_size]] for i in range(batch_size):
sz = np.random.randint(x_batch.shape[1])
x_batch[i] = np.roll(x_batch[i], sz, axis = 0) yield x_batch, y_batch
#Start with a slightly lower learning rate, to ensure convergence
model.compile(optimizer=Adam(1e-5), loss = 'binary_crossentropy', metrics=['accuracy'])
hist = model.fit_generator(batch_generator(x_train, y_train, 32),
validation_data=(x_test, y_test),
verbose=0, epochs=5,
steps_per_epoch=x_train.shape[1]//32)
#Then speed things up a little
model.compile(optimizer=Adam(4e-5), loss = 'binary_crossentropy', metrics=['accuracy'])
hist = model.fit_generator(batch_generator(x_train, y_train, 32),
validation_data=(x_test, y_test),
verbose=2, epochs=40,
steps_per_epoch=x_train.shape[1]//32)
plt.plot(hist.history['loss'], color='b')
plt.plot(hist.history['val_loss'], color='r')
plt.show()
plt.plot(hist.history['acc'], color='b')
plt.plot(hist.history['val_acc'], color='r')
plt.show()
non_idx = np.where(y_test[:,0] == 0.)[0]
yes_idx = np.where(y_test[:,0] == 1.)[0]
y_hat = model.predict(x_test)[:,0]
plt.plot([y_hat[i] for i in yes_idx], 'bo')
plt.show()
plt.plot([y_hat[i] for i in non_idx], 'ro')
plt.show()
y_true = (y_test[:, 0] + 0.5).astype("int")
fpr, tpr, thresholds = roc_curve(y_true, y_hat)
plt.plot(thresholds, 1.-fpr)
plt.plot(thresholds, tpr)
plt.show()
crossover_index = np.min(np.where(1.-fpr <= tpr))
crossover_cutoff = thresholds[crossover_index]
crossover_specificity = 1.-fpr[crossover_index]
print("Crossover at {0:.2f} with specificity {1:.2f}".format(crossover_cutoff, crossover_specificity))
plt.plot(fpr, tpr)
plt.show()
print("ROC area under curve is {0:.2f}".format(roc_auc_score(y_true, y_hat)))
false_positives = np.where(y_hat * (1. - y_test) > 0.5)[0]
for i in non_idx:
if y_hat[i] > crossover_cutoff:
print(i)
plt.plot(x_test[i])
plt.show()
吴裕雄--天生自然 PYTHON数据分析:基于Keras的CNN分析太空深处寻找系外行星数据的更多相关文章
- 吴裕雄--天生自然 python数据分析:健康指标聚集分析(健康分析)
# This Python 3 environment comes with many helpful analytics libraries installed # It is defined by ...
- 吴裕雄--天生自然 PYTHON数据分析:钦奈水资源管理分析
df = pd.read_csv("F:\\kaggleDataSet\\chennai-water\\chennai_reservoir_levels.csv") df[&quo ...
- 吴裕雄--天生自然 python数据分析:基于Keras使用CNN神经网络处理手写数据集
import pandas as pd import numpy as np import matplotlib.pyplot as plt import matplotlib.image as mp ...
- 吴裕雄--天生自然 PYTHON数据分析:糖尿病视网膜病变数据分析(完整版)
# This Python 3 environment comes with many helpful analytics libraries installed # It is defined by ...
- 吴裕雄--天生自然 PYTHON数据分析:所有美国股票和etf的历史日价格和成交量分析
# This Python 3 environment comes with many helpful analytics libraries installed # It is defined by ...
- 吴裕雄--天生自然 python数据分析:葡萄酒分析
# import pandas import pandas as pd # creating a DataFrame pd.DataFrame({'Yes': [50, 31], 'No': [101 ...
- 吴裕雄--天生自然 PYTHON数据分析:人类发展报告——HDI, GDI,健康,全球人口数据数据分析
import pandas as pd # Data analysis import numpy as np #Data analysis import seaborn as sns # Data v ...
- 吴裕雄--天生自然 python数据分析:医疗费数据分析
import numpy as np import pandas as pd import os import matplotlib.pyplot as pl import seaborn as sn ...
- 吴裕雄--天生自然 PYTHON数据分析:医疗数据分析
import numpy as np # linear algebra import pandas as pd # data processing, CSV file I/O (e.g. pd.rea ...
随机推荐
- 实现迭代器(\_\_next\_\_和\_\_iter\_\_)
目录 实现迭代器(__next__和__iter__) 一.简单示例 二.StopIteration异常版 三.模拟range 四.斐波那契数列 实现迭代器(__next__和__iter__) 一. ...
- Long型转ZonedDateTime型
/** * 将Long类型转化成0 * @author yk * @param time * @return */public static ZonedDateTime toZonedDateTime ...
- goweb-处理静态资源
处理静态文件 对于 HTML 页面中的 css 以及 js 等静态文件,需要使用使用 net/http 包下的以下 方法来处理 1) StripPrefix 函数 2) FileServer 函数 3 ...
- linux中ftp中文名乱码问题
问题触发环境 1. java中使用org.apache.commons.net.ftp.FTPClient包 2. 通过chrome浏览器的file标签上传文件 3. 在windows上部署的File ...
- 4. 监控利器nagios手把手企业级实战第三部
1.nagios图形监控显示和管理服务器 虽然能显示,能报警.但是我们企业工作中需要一个历史趋势图. nagios只开放核心,插件是单独的形式,图像也一样,是插件或者整合的方式.所以可能看起来很多,这 ...
- java第三方工具包
--搜集于网络 1.Apache POI 处理office文档用到的2. IText PDF操作类库 3.Java Base64 Base64编码类库 4.Commons-lang 对应java sd ...
- 14 微服务电商【黑马乐优商城】:day03-springcloud(Zuul网关)
本项目的笔记和资料的Download,请点击这一句话自行获取. day01-springboot(理论篇) :day01-springboot(实践篇) day02-springcloud(理论篇一) ...
- Java之接口(java8的新特性)
public class SubClassTest { public static void main(String[] args) { SubClass s = new SubClass(); // ...
- Opencv笔记(八)——图像上的算数运算
学习目标: 学习图像上的算术运算,加法,减法,位运算等. 学习函数cv2.add(),cv2.addWeighted() 等. 一.图像的加法 你可以使用函数 cv2.add() 将两幅图像进行加法运 ...
- GB35658较796新增检测项部标平台
GB35658较796新增检测项部标平台总共有113项,总结归类如下:1 报表导出 支持excel格式的报表导出 对查询.统计报表提供excel格式的报表导出 必选: 2 ...