本人代码库: https://github.com/beathahahaha/tensorflow-DeepFM-master-original

DeepFM原作者代码库: https://github.com/ChenglongChen/tensorflow-DeepFM

解析DeepFM代码 博客推荐:https://mp.weixin.qq.com/s/QrO48ZdP483TY_EnnWFhsQ

为了熟悉该代码的使用,我在example文件夹编写了一个test_1.py文件,可以直接运行

一、定义DeepFM 输入:

  需要train.csv(59列,有连续性数值,也有离散型数值,其中多分类都用的0,1,2,3表示),test.csv是kaggle比赛时需要输出的东西,非必要

  (参考该数据格式:https://www.kaggle.com/c/porto-seguro-safe-driver-prediction/data?select=train.csv)

二、定义DeepFM 输出:

  yy = dfm.predict(Xi_valid_, Xv_valid_) 得到一维np.array,其中数值为float代表概率值

tensorflow 建议1.14 gpu版本

如果自己要DIY的话,要注意哪些地方呢?

答:

1. config.py 里面的设置,和输入数据密切相关,要定义好离散型和连续型的列

2. 喂入的数据格式必须严格统一,注意修改test_1.py 中的列标签名字相关的内容(因此建议使用test_1.py 而不是原作者的main.py)

test_1.py:

import tensorflow as tf
from sklearn.metrics import roc_auc_score
import os
import sys import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from sklearn.metrics import make_scorer
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import accuracy_score import config
from metrics import gini_norm
from DataReader import FeatureDictionary, DataParser sys.path.append("..")
from DeepFM import DeepFM def _load_data():
dfTrain = pd.read_csv(config.TRAIN_FILE)
dfTest = pd.read_csv(config.TEST_FILE) cols = [c for c in dfTrain.columns if c not in ["id", "target"]]
cols = [c for c in cols if (not c in config.IGNORE_COLS)] X_train = dfTrain[cols].values
y_train = dfTrain["target"].values
X_test = dfTest[cols].values
ids_test = dfTest["id"].values
cat_features_indices = [i for i, c in enumerate(cols) if c in config.CATEGORICAL_COLS] return dfTrain, dfTest, X_train, y_train, X_test, ids_test, cat_features_indices def _run_base_model_dfm(dfTrain, dfTest, folds, dfm_params):
fd = FeatureDictionary(dfTrain=dfTrain, dfTest=dfTest,
numeric_cols=config.NUMERIC_COLS,
ignore_cols=config.IGNORE_COLS)
data_parser = DataParser(feat_dict=fd)
Xi_train, Xv_train, y_train = data_parser.parse(df=dfTrain, has_label=True)
Xi_test, Xv_test, ids_test = data_parser.parse(df=dfTest) dfm_params["feature_size"] = fd.feat_dim
dfm_params["field_size"] = len(Xi_train[0]) y_train_meta = np.zeros((dfTrain.shape[0], 1), dtype=float)
y_test_meta = np.zeros((dfTest.shape[0], 1), dtype=float)
_get = lambda x, l: [x[i] for i in l]
gini_results_cv = np.zeros(len(folds), dtype=float)
gini_results_epoch_train = np.zeros((len(folds), dfm_params["epoch"]), dtype=float)
gini_results_epoch_valid = np.zeros((len(folds), dfm_params["epoch"]), dtype=float)
for i, (train_idx, valid_idx) in enumerate(folds):
# k折交叉,每一折中的fit中,含有epoch轮训练,每一次epoch拆分了batch来喂入
Xi_train_, Xv_train_, y_train_ = _get(Xi_train, train_idx), _get(Xv_train, train_idx), _get(y_train, train_idx)
Xi_valid_, Xv_valid_, y_valid_ = _get(Xi_train, valid_idx), _get(Xv_train, valid_idx), _get(y_train, valid_idx) dfm = DeepFM(**dfm_params)
dfm.fit(Xi_train_, Xv_train_, y_train_, Xi_valid_, Xv_valid_, y_valid_) # fit中包含对train和valid的评估 yy = dfm.predict(Xi_valid_, Xv_valid_)
# print("type(yy):",type(yy))
# print("type(y_valid_):", type(y_valid_)) # print("yy.shape:",yy.shape) #yy : array
# print("y_valid_.shape:", y_valid_.shape) #y_valid_ : list #print("yy:", yy) # 原始的predict出来的是概率值
for index in range(len(yy)):
if (yy[index] <= 0.5):
yy[index] = 0
else:
yy[index] = 1 #print("y_valid_:", y_valid_) print("accuracy_score(y_valid_, yy):", accuracy_score(y_valid_, yy)) y_train_meta[valid_idx, 0] = yy y_test_meta[:, 0] += dfm.predict(Xi_test, Xv_test) y_test_meta /= float(len(folds)) return y_train_meta, y_test_meta # params
dfm_params = {
"use_fm": True,
"use_deep": True,
"embedding_size": 8,
"dropout_fm": [1.0, 1.0],
"deep_layers": [32, 32],
"dropout_deep": [0.5, 0.5, 0.5],
"deep_layers_activation": tf.nn.relu,
"epoch": 10,
"batch_size": 1024,
"learning_rate": 0.001,
"optimizer_type": "adam",
"batch_norm": 1,
"batch_norm_decay": 0.995,
"l2_reg": 0.01,
"verbose": True,
"eval_metric": roc_auc_score,
"random_seed": 2017
} dfTrain, dfTest, X_train, y_train, X_test, ids_test, cat_features_indices = _load_data() folds = list(StratifiedKFold(n_splits=config.NUM_SPLITS, shuffle=True,
random_state=config.RANDOM_SEED).split(X_train, y_train)) y_train_dfm, y_test_dfm = _run_base_model_dfm(dfTrain, dfTest, folds, dfm_params) print("over") # Xi_train, Xv_train, y_train = prepare(...)
# Xi_valid, Xv_valid, y_valid = prepare(...)

DeepFM——tensorflow代码改编的更多相关文章

  1. tensorflow 代码阅读

    具体实现: https://github.com/tensorflow/tensorflow/tree/master/tensorflow/core/framework 『深度长文』Tensorflo ...

  2. 关于使用实验室服务器的GPU以及跑上TensorFlow代码

    连接服务器 Windows - XShell XFtp SSH 通过SSH来连接实验室的服务器 使用SSH连接已经不陌生了 github和OS课设都经常使用 目前使用 192.168.7.169 使用 ...

  3. 条件随机场(crf)及tensorflow代码实例

    对于条件随机场的学习,我觉得应该结合HMM模型一起进行对比学习.首先浏览HMM模型:https://www.cnblogs.com/pinking/p/8531405.html 一.定义 条件随机场( ...

  4. 如何高效的学习 TensorFlow 代码? 以及TensorFlow相关的论文

    https://www.zhihu.com/question/41667903 源码分析 http://www.cnblogs.com/yao62995/p/5773578.html 如何贡献Tens ...

  5. Transformer解析与tensorflow代码解读

    本文是针对谷歌Transformer模型的解读,根据我自己的理解顺序记录的. 另外,针对Kyubyong实现的tensorflow代码进行解读,代码地址https://github.com/Kyuby ...

  6. 深度学习之卷积神经网络CNN及tensorflow代码实例

    深度学习之卷积神经网络CNN及tensorflow代码实例 什么是卷积? 卷积的定义 从数学上讲,卷积就是一种运算,是我们学习高等数学之后,新接触的一种运算,因为涉及到积分.级数,所以看起来觉得很复杂 ...

  7. 深度学习之卷积神经网络CNN及tensorflow代码实现示例

    深度学习之卷积神经网络CNN及tensorflow代码实现示例 2017年05月01日 13:28:21 cxmscb 阅读数 151413更多 分类专栏: 机器学习 深度学习 机器学习   版权声明 ...

  8. 运行TensorFlow代码时报错

    运行TensorFlow代码时报错 错误信息ImportError: libcublas.so.10.0: cannot open shared object file 原因:TensorFlow版本 ...

  9. 利用VGG19实现火灾分类(附tensorflow代码及训练集)

    源码地址 https://github.com/stephen-v/tensorflow_vgg_classify 1. VGG介绍 1.1. VGG模型结构 1.2. VGG19架构 2. 用Ten ...

随机推荐

  1. 主题包含一张index.html

    有半年之久没有更新新作品了,但这个小小领地我并没有忘记,我会坚持下去,一直在这等你,等你的每次回眸,感恩你的每次驻足,这已经足够成为我坚守的动力和理由,尽管现在有很多不足和不尽人意,也没很多的时间管理 ...

  2. shell 脚本之set 命令(转)

    服务器的开发和管理离不开 Bash 脚本,掌握它需要学习大量的细节. set命令是 Bash 脚本的重要环节,却常常被忽视,导致脚本的安全性和可维护性出问题.本文介绍它的基本用法,让你可以更安心地使用 ...

  3. 从ceph对象中提取RBD中的指定文件

    前言 之前有个想法,是不是有办法找到rbd中的文件与对象的关系,想了很久但是一直觉得文件系统比较复杂,在fs 层的东西对ceph来说是透明的,并且对象大小是4M,而文件很小,可能在fs层进行了合并,应 ...

  4. 流量控制--5.Classless Queuing Disciplines (qdiscs)

    Classless Queuing Disciplines (qdiscs) 本文涉及的队列规则(Qdisc)都可以作为接口上的主qdisc,或作为一个classful qdiscs的叶子类.这些是L ...

  5. 网络发布工具 Apache/Nginx

    四大主流发布服务器 注:发布服务器的背后都是socket套接字 1.Apache阿帕奇 - 多进程 2.IIS -多线程 3.Nginx (engine x)(新) -支持异步IO,是现在最快的发布服 ...

  6. SQL Server DATEDIFF() 函数用法

    定义和用法 DATEDIFF() 函数返回两个日期之间的时间,例如计算年龄大小. DATEDIFF(datepart,startdate,enddate)startdate 和 enddate 参数是 ...

  7. Maven项目关系

    Maven是一个项目管理工具,它包含了一个项目对象模型 (Project Object Model),其中最重要的就是POM文件,可以指定项目类型,项目关系等信息,maven项目之间有三种关系. 依赖 ...

  8. Apache Shiro 反序列化漏洞复现(CVE-2016-4437)

    漏洞描述 Apache Shiro是一个Java安全框架,执行身份验证.授权.密码和会话管理.只要rememberMe的AES加密密钥泄露,无论shiro是什么版本都会导致反序列化漏洞. 漏洞原理 A ...

  9. 面试官:就问个Spring容器初始化和Bean对象的创建,你讲一小时了

    前言 spring作为一个容器,可以管理对象的生命周期.对象与对象之间的依赖关系.可以通过配置文件,来定义对象,以及设置其与其他对象的依赖关系. main测试类 public static void ...

  10. 2020阿里Java面试题目大汇总,看看你离阿里还有多远,附答案!

    前言 首先说一下情况,我大概我是从去年12月份开始看书学习,到今年的6月份,一直学到看大家的面经基本上百分之90以上都会,我就在5月份开始投简历,边面试边补充基础知识等.也是有些辛苦.终于是在前不久拿 ...