tensorflow可以很方便的添加用户自定义的operator(如果不添加也可以采用sklearn的auc计算函数或者自己写一个
但是会在python执行,这里希望在graph中也就是c++端执行这个计算)

这里根据工作需要添加一个计算auc的operator,只给出最简单实现,后续高级功能还是参考官方wiki

https://www.tensorflow.org/versions/r0.7/how_tos/adding_an_op/index.html

注意tensorflow现在和最初的官方wiki有变化,原wiki貌似是需要重新bazel编译整个tensorflow,然后使用比如tf.user_op.auc这样。

目前wiki给出的方式>=0.6.0版本,采用plug-in的方式,更加灵活可以直接用g++编译一个so载入,解耦合,省去了编译tensorflow过程,即插即用。

 
 

首先auc的operator计算的文件

 
 

tensorflow/core/user_ops/auc.cc

 
 

/* Copyright 2015 Google Inc. All Rights Reserved.

 
 

Licensed under the Apache License, Version 2.0 (the "License");

you may not use this file except in compliance with the License.

You may obtain a copy of the License at

 
 

http://www.apache.org/licenses/LICENSE-2.0

 
 

Unless required by applicable law or agreed to in writing, software

distributed under the License is distributed on an "AS IS" BASIS,

WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.

See the License for the specific language governing permissions and

limitations under the License.

==============================================================================*/

 
 

// An auc Op.

 
 

#include
"tensorflow/core/framework/op.h"

#include
"tensorflow/core/framework/op_kernel.h"

 
 

using
namespace tensorflow;

using
std::vector;

//@TODO add weight as optional input

REGISTER_OP("Auc")

.Input("predicts: T1")

.Input("labels: T2")

.Output("z: float")

.Attr("T1: {float, double}")

.Attr("T2: {float, double}")

//.Attr("T1: {float, double}")

//.Attr("T2: {int32, int64}")

.SetIsCommutative()

.Doc(R"doc(

Given preidicts and labels output it's auc

)doc");

 
 

class
AucOp : public OpKernel {

public:

explicit
AucOp(OpKernelConstruction* context) : OpKernel(context) {}

 
 

template<typename
ValueVec>

void
index_sort(const
ValueVec& valueVec, vector<int>& indexVec)

{

indexVec.resize(valueVec.size());

for (size_t
i = 0; i < indexVec.size(); i++)

{

indexVec[i] = i;

}

std::sort(indexVec.begin(), indexVec.end(),

[&valueVec](const
int
l, const
int
r) { return
valueVec(l) > valueVec(r); });

}

 
 

void
Compute(OpKernelContext* context) override {

// Grab the input tensor

const Tensor& predicts_tensor = context->input(0);

const Tensor& labels_tensor = context->input(1);

auto
predicts = predicts_tensor.flat<float>(); //输入能接受float double那么这里如何都处理?

auto
labels = labels_tensor.flat<float>();

 
 

vector<int> indexes;

index_sort(predicts, indexes);

typedef
float
Float;

 
 

Float
oldFalsePos = 0;

Float
oldTruePos = 0;

Float
falsePos = 0;

Float
truePos = 0;

Float
oldOut = std::numeric_limits<Float>::infinity();

Float
result = 0;

 
 

for (size_t
i = 0; i < indexes.size(); i++)

{

int
index = indexes[i];

Float
label = labels(index);

Float
prediction = predicts(index);

Float
weight = 1.0;

//Pval3(label, output, weight);

if (prediction != oldOut) //存在相同值得情况是特殊处理的

{

result += 0.5 * (oldTruePos + truePos) * (falsePos - oldFalsePos);

oldOut = prediction;

oldFalsePos = falsePos;

oldTruePos = truePos;

}

if (label > 0)

truePos += weight;

else

falsePos += weight;

}

result += 0.5 * (oldTruePos + truePos) * (falsePos - oldFalsePos);

Float
AUC = result / (truePos * falsePos);

 
 

// Create an output tensor

Tensor* output_tensor = NULL;

TensorShape output_shape;

 
 

OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output_tensor));

output_tensor->scalar<float>()() = AUC;

}

};

 
 

REGISTER_KERNEL_BUILDER(Name("Auc").Device(DEVICE_CPU), AucOp);

 
 

 
 

编译:

$cat gen-so.sh

 
 

TF_INC=$(python -c 'import tensorflow as tf; print(tf.sysconfig.get_include())')

TF_LIB=$(python -c 'import tensorflow as tf; print(tf.sysconfig.get_lib())')

i=$1

o=${i/.cc/.so}

g++ -std=c++11 -shared $i -o $o -I $TF_INC -l tensorflow_framework -L $TF_LIB -fPIC -Wl,-rpath $TF_LIB

 
 

$sh gen-so.sh auc.cc

会生成auc.so

 
 

使用的时候

auc_module = tf.load_op_library('auc.so')

#auc = tf.user_ops.auc #0.6.0之前的tensorflow 自定义op方式

auc = auc_module.auc

 
 

evaluate_op = auc(py_x, Y) #py_x is predicts, Y is labels

 
 

 
 

 
 

 
 

 
 

 
 

tensorflow添加自定义的auc计算operator的更多相关文章

  1. AUC计算 - 进阶操作

    首先AUC值是一个概率值,当你随机挑选一个正样本以及负样本,当前的分类算法根据计算得到的Score值将这个正样本排在负样本前面的概率就是AUC值,AUC值越大,当前分类算法越有可能将正样本排在负样本前 ...

  2. ROC 曲线,以及AUC计算方式

    ROC曲线: roc曲线:接收者操作特征(receiveroperating characteristic),roc曲线上每个点反映着对同一信号刺激的感受性. ROC曲线的横轴: 负正类率(false ...

  3. AUC计算 - 手把手步进操作

    2017-07-10 14:38:24 理论参考: 评估分类器性能的度量,像混淆矩阵.ROC.AUC等 http://www.cnblogs.com/suanec/p/5941630.html ROC ...

  4. 110、TensorFlow张量值的计算

    import tensorflow as tf #placeholders在没有提供具体值的时候不能使用eval方法来计算它的值 # 另外的建模方法可能会使得模型变得复杂 # TensorFlow 不 ...

  5. TensorFlow两种方式计算Cross Entropy

    sparse_softmax_cross_entropy_with_logits与softmax_cross_entropy_with_logits import tensorflow as tf y ...

  6. 学习笔记TF067:TensorFlow Serving、Flod、计算加速,机器学习评测体系,公开数据集

    TensorFlow Serving https://tensorflow.github.io/serving/ . 生产环境灵活.高性能机器学习模型服务系统.适合基于实际数据大规模运行,产生多个模型 ...

  7. tensorflow入门教程和底层机制简单解说——本质就是图计算,自动寻找依赖,想想spark机制就明白了

    简介 本章的目的是让你了解和运行 TensorFlow! 在开始之前, 让我们先看一段使用 Python API 撰写的 TensorFlow 示例代码, 让你对将要学习的内容有初步的印象. 这段很短 ...

  8. 机器学习的敲门砖:手把手教你TensorFlow初级入门

    摘要: 在开始使用机器学习算法之前,我们应该首先熟悉如何使用它们. 而本文就是通过对TensorFlow的一些基本特点的介绍,让你了解它是机器学习类库中的一个不错的选择. 本文由北邮@爱可可-爱生活  ...

  9. TensorFlow基础笔记(6) 图像风格化实验

    参考 http://blog.csdn.net/wspba/article/details/53994649 https://www.ctolib.com/AdaIN-style.html Ackno ...

随机推荐

  1. C#中中文编码的问题(StreamWriter和StreamReader默认编码)

    在使用StreamWriter和StreamReader时产生了这样的疑问,在不指定的情况下,他们使用什么编码方式? 查看MSDN,请看下图: 注意红色区域  这让我以为构造函数参数不同时使用不一样的 ...

  2. ssh项目将搜索条件进行联动

    <s:form namespace="/tb" action="tenderList" name="searchForm" id=&q ...

  3. [CG编程] 基本光照模型的实现与拓展以及常见光照模型解析

    0.前言 这篇文章写于去年的暑假.大二的假期时间多,小组便开发一个手机游戏的项目,开发过程中忙里偷闲地了解了Unity的shader编写,而CG又与shaderLab相似,所以又阅读了<CG教程 ...

  4. connect mysql

    #!/usr/bin/python# -*- coding:utf-8 -*- import MySQLdb db = MySQLdb.connect("127.0.0.1", & ...

  5. 第二轮冲刺-Runner站立会议06

    今天:解决连接问题 明天:编写日历界面 困难:暂无

  6. 20145204&20145212实验二报告

    实验二固件设计 步骤: 1.开发环境的配置,参考实验一 1.将实验代码拷贝到共享文件夹中. 2.在虚拟机中编译代码.对于多线程相关的代码,编译时需要加-lpthread的库.下载调试在超级终端中运行可 ...

  7. web service 学习

    是什么? 是一种远程调用技术,这种技术提供一些接口,这些接口实现让客户端和服务端进行通信和数据交换,并且让通信和交换与平台和开发语言无关.也可以说是提供了许多函数.客户端调用服务端的函数. 远程调用: ...

  8. 正则表达式之g标志,match和 exec

    1.g标志    g标志一般是与match和exec来连用,否则g标志没有太大的意义. 先来看一个带g标志的例子: var str = "tankZHang (231144) tank yi ...

  9. python 模块基础介绍

    从逻辑上组织代码,将一些有联系,完成特定功能相关的代码组织在一起,这些自我包含并且有组织的代码片段就是模块,将其他模块中属性附加到你的模块的操作叫做导入. 那些一个或多个.py文件组成的代码集合就称为 ...

  10. Gdb调试多进程程序

    Gdb调试多进程程序 程序经常使用fork/exec创建多进程程序.多进程程序有自己独立的地址空间,这是多进程调试首要注意的地方.Gdb功能强大,对调试多线程提供很多支持. 方法1:调试多进程最土的办 ...