pytorch中调用C进行扩展
pytorch中调用C进行扩展,使得某些功能在CPU上运行更快;
第一步:编写头文件
/* src/my_lib.h */
int my_lib_add_forward(THFloatTensor *input1, THFloatTensor *input2, THFloatTensor *output);
int my_lib_add_backward(THFloatTensor *grad_output, THFloatTensor *grad_input);
第二步:编写源文件
/* src/my_lib.c */
#include <TH/TH.h> int my_lib_add_forward(THFloatTensor *input1, THFloatTensor *input2,
THFloatTensor *output)
{
if (!THFloatTensor_isSameSizeAs(input1, input2))
return ;
THFloatTensor_resizeAs(output, input1);
THFloatTensor_cadd(output, input1, 1.0, input2);
return ;
} int my_lib_add_backward(THFloatTensor *grad_output, THFloatTensor *grad_input)
{
THFloatTensor_resizeAs(grad_input, grad_output);
THFloatTensor_fill(grad_input, );
return ;
}
注意:头文件TH就是pytorch底层代码的接口头文件,它是CPU模式,GPU下则为THC;

第三步:在同级目录下创建一个.py文件(比如叫“build.py”)
该文件用于对该C扩展模块进行编译(使用torch.util.ffi模块进行扩展编译);
# build.py
from torch.utils.ffi import create_extension
ffi = create_extension(
name='_ext.my_lib', # 输出文件地址及名称
headers='src/my_lib.h', # 编译.h文件地址及名称
sources=['src/my_lib.c'], # 编译.c文件地址及名称
with_cuda=False # 不使用cuda
)
ffi.build()
第四步:编写.py脚本调用编译好的C扩展模块
import torch
from torch.autograd import Function
from _ext import my_lib
import torch.nn as nn class MyAddFunction(Function):
def forward(self, input1, input2):
output = torch.FloatTensor()
my_lib.my_lib_add_forward(input1, input2, output)
return output def backward(self, grad_output):
grad_input = torch.FloatTensor()
my_lib.my_lib_add_backward(grad_input, grad_output)
return grad_input class MyAddModule(nn.Module):
def forward(self, input1, input2):
return MyAddFunction()(input1, input2) class MyNetWork(nn.Module):
def __init__(self):
super(MyNetWork, self).__init__()
self.add = MyAddModule() def forward(self, input1, input2):
return self.add(input1, input2) model = MyNetWork()
input1, input2 = torch.randn(5, 5), torch.randn(5, 5)
print(model(input1, input2))
print(input1 + input2)
至此,用这个简单的例子抛砖引玉~
pytorch中调用C进行扩展的更多相关文章
- tp中调用PHP系统扩展类
例如使用Redis扩展类: use Reids; $redis = new Redis();
- PyTorch中的C++扩展
今天要聊聊用 PyTorch 进行 C++ 扩展. 在正式开始前,我们需要了解 PyTorch 如何自定义module.这其中,最常见的就是在 python 中继承torch.nn.Module,用 ...
- iOS 中 h5 页面 iframe 调用高度自扩展问题及解决
开发需求需要在 h5 中用 iframe 中调用一个其他公司开发的 html 页面. 简单的插入 <iframe /> 并设置宽高后,发现在 Android 手机浏览器上打开可以正常运行, ...
- C#中如果类的扩展方法和类本身的方法签名相同,那么会优先调用类本身的方法
新建一个.NET Core项目,假如我们有如下代码: using System; namespace MethodOverload { static class DemoExtension { pub ...
- pytorch中使用cuda扩展
以下面这个例子作为教程,实现功能是element-wise add: (pytorch中想调用cuda模块,还是用另外使用C编写接口脚本) 第一步:cuda编程的源文件和头文件 // mathutil ...
- Unity中调用Windows窗口句柄以及根据需求设置并且解决扩展屏窗体显示错乱/位置错误的Bug
问题背景: 现在在搞PC端应用开发,我们开发中需要调用系统的窗口以及需要最大化最小化,缩放窗口拖拽窗口,以及设置窗口位置,去边框等功能 解决根据: 使用user32.dll解决 具体功能: Unity ...
- Pytorch中RoI pooling layer的几种实现
Faster-RCNN论文中在RoI-Head网络中,将128个RoI区域对应的feature map进行截取,而后利用RoI pooling层输出7*7大小的feature map.在pytorch ...
- WebApi接口 - 如何在应用中调用webapi接口
很高兴能再次和大家分享webapi接口的相关文章,本篇将要讲解的是如何在应用中调用webapi接口:对于大部分做内部管理系统及类似系统的朋友来说很少会去调用别人的接口,因此可能在这方面存在一些困惑,希 ...
- Mybatis中SqlMapper配置的扩展与应用(1)
奋斗了好几个晚上调试程序,写了好几篇博客,终于建立起了Mybatis配置的扩展机制.虽然扩展机制是重要的,然而如果没有真正实用的扩展功能,那也至少是不那么鼓舞人心的,这篇博客就来举几个扩展的例子. 这 ...
随机推荐
- 使用Cloudera Manager添加Sentry服务
使用Cloudera Manager添加Sentry服务 作者:尹正杰 版权声明:原创作品,谢绝转载!否则将追究法律责任. 一.通过CM添加Sentry服务 1>.点击添加服务进入CM服务安装向 ...
- K8s基本概念入门
序言 没等到风来,绵绵小雨,所以写个随笔,聊聊k8s的基本概念. k8s是一个编排容器的工具,其实也是管理应用的全生命周期的一个工具,从创建应用,应用的部署,应用提供服务,扩容缩容应用,应用更新,都非 ...
- Uva1349Optimal Bus Route Design(二分图最佳完美匹配)(最小值)
题意: 给定n个点的有向图问,问能不能找到若干个环,让所有点都在环中,且让权值最小,KM算法求最佳完美匹配,只不过是最小值,所以把边权变成负值,输出时将ans取负即可 这道题是在VJ上交的 #incl ...
- docker学习4-docker安装mysql环境
前言 docker安装mysql环境非常方便,简单的几步操作就可以了 拉取mysql镜像 先拉取mysql的镜像,可以在docker的镜像仓库找到不同TAG标签的版本https://hub.docke ...
- LeetCode 1046. Last Stone Weight
原题链接在这里:https://leetcode.com/problems/last-stone-weight/ 题目: We have a collection of rocks, each roc ...
- SpringBoot第二节(SpringBoot整合Mybatis)
1.创建Spring Initiallizr项目 一直点击下一步 2.引入依赖 <dependencies> <dependency> <groupId>org.s ...
- linux 出错 “INFO: task java: xxx blocked for more than 120 seconds.” 的3种解决方案
1 问题描述 最近搭建的一个linux最小系统在运行到241秒时在控制台自动打印如下图信息,并且以后每隔120秒打印一次. 仔细阅读打印信息发现关键信息是“hung_task_timeout_secs ...
- 使用for循环签到嵌套制作直角三角形
注意代码的运行顺序: for(i = 0 ; i<9 ; i++){ for(j = 0 ; j<i-1 ; j++){ document.write("*")//** ...
- [Gradle] 解决高德 jar 包打包到 aar 后 jar 包中的 assets 内容丢失
问题描述 将高德 SDK 的 jar 包放到 android library project libs 目录下,发布为 aar 包后,发现高德 jar 包中的 assets 目录下的内容不见了 原因见 ...
- 【cf contest 1119 G】Get Ready for the Battle
题目 你有\(n\)个士兵,需要将他们分成\(m\)组,每组可以为0: 现在这些士兵要去攻打\(m\)个敌人,每个敌人的生命值为\(hp_i\) : 一轮游戏中一组士兵选定一个攻打的敌人,敌人生命值- ...