1 WMMA (Warp-level Matrix Multiply Accumulate) API

对于计算能力在7.0及以上的CUDA设备,可以使用CUDA C++ API调用Tensor Core,支持形如D = AB + C的混合精度的矩阵乘运算。
template<typename Use, int m, int n, int k, typename T, typename Layout=void> class fragment;

void load_matrix_sync(fragment<...> &a, const T* mptr, unsigned ldm);
void load_matrix_sync(fragment<...> &a, const T* mptr, unsigned ldm, layout_t layout);
void store_matrix_sync(T* mptr, const fragment<...> &a, unsigned ldm, layout_t layout);
void fill_fragment(fragment<...> &a, const T& v);
void mma_sync(fragment<...> &d, const fragment<...> &a, const fragment<...> &b, const fragment<...> &c, bool satf=false);
  • fragment:Tensor Core数据存储类,支持matrix_a、matrix_b和accumulator
  • load_matrix_sync:Tensor Core数据加载API,支持将矩阵数据从global memory或shared memory加载到fragment
  • store_matrix_sync:Tensor Core结果存储API,支持将计算结果从fragment存储到global memory或shared memory
  • fill_fragment:fragment填充API,支持常数值填充
  • mma_sync:Tensor Core矩阵乘计算API,支持D = AB + C或者C = AB + C

2 示例

以m16n16k16为例,实现HGEMM:C = AB,其中矩阵A(M * K,row major)、B(K * N,col major)和C(M * N,row major)的精度均为FP16。首先我们看如何使用CUDA Core写HGEMM naive算法。

2.1 CUDA Core

按照每个线程计算矩阵C中的一个元素来构建naive kernel,首先确定当前线程处理矩阵C的元素坐标,再遍历K并直接从global memory中加载所需A、B矩阵元素到寄存器参与计算,最后将计算结果从寄存器直接写回矩阵C。所有block计算完成之后即可得到矩阵C。这个例子不能说简单,只能说技术含量不高,不过我们只是为了对比。
#define DIV_CEIL(x, y) (((x) + (y) - 1) / (y))

__global__ void naiveKernel(const half *__restrict__ A, const half *__restrict__ B, half *__restrict__ C, size_t M,
size_t N, size_t K) {
size_t row = threadIdx.x + blockDim.x * blockIdx.x;
size_t col = threadIdx.y + blockDim.y * blockIdx.y;
if (row < M && col < N) {
half tmp = 0.0;
for (size_t i = 0; i < K; ++i) {
tmp += A[row * K + i] * B[i + col * K];
}
C[row * N + col] = tmp;
}
} void hgemmNaive(half *A, half *B, half *C, size_t M, size_t N, size_t K) {
dim3 block(16, 16);
dim3 grid(DIV_CEIL(M, block.x), DIV_CEIL(N, block.y)); naiveKernel<<<grid, block>>>(A, B, C, M, N, K);
}

2.2 Tensor Core

我们再来看如何用WMMA API来构建naive kernel,参考cuda sample。与CUDA Core naive不同的是,WMMA需要按照每个warp处理一个矩阵C的WMMA_M * WMMA_N大小的tile的思路来构建,因为Tensor Core的计算层级是warp级别,计算的矩阵元素也是二维的。接下来,与CUDA Core naive的处理思路一致,首先确定当前warp处理矩阵C的tile坐标,声明计算tilie所需的fragment,再以WMMA_K为步长遍历K并直接从global memory中加载所需A、B矩阵tile到fragment参与计算,最后将计算结果从fragment直接写回矩阵C。所有block计算完成之后即可得到矩阵C。
值得注意的是,load_matrix_sync和store_matrix_sync都是按stride访问矩阵元素。
#include <mma.h>

#define WARP_SIZE 32

#define WMMA_M 16
#define WMMA_N 16
#define WMMA_K 16 using namespace nvcuda; __global__ void wmmaNaiveKernel(const half *__restrict__ A, const half *__restrict__ B, half *__restrict__ C, size_t M,
size_t N, size_t K) {
size_t warpM = (blockIdx.x * blockDim.x + threadIdx.x) / WARP_SIZE;
size_t warpN = (blockIdx.y * blockDim.y + threadIdx.y); wmma::fragment<wmma::matrix_a, WMMA_M, WMMA_N, WMMA_K, half, wmma::row_major> a_frag;
wmma::fragment<wmma::matrix_b, WMMA_M, WMMA_N, WMMA_K, half, wmma::col_major> b_frag;
wmma::fragment<wmma::accumulator, WMMA_M, WMMA_N, WMMA_K, half> c_frag; wmma::fill_fragment(c_frag, 0.0f); for (size_t i = 0; i < K; i += WMMA_K) {
size_t aCol = i;
size_t aRow = warpM * WMMA_M;
size_t bCol = warpN * WMMA_N;
size_t bRow = i; if (aRow < M && aCol < K && bRow < K && bCol < N) {
wmma::load_matrix_sync(a_frag, A + aCol + aRow * K, K);
wmma::load_matrix_sync(b_frag, B + bRow + bCol * K, K); wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
}
} size_t cCol = warpN * WMMA_N;
size_t cRow = warpM * WMMA_M; if (cRow < M && cCol < N) {
wmma::store_matrix_sync(C + cCol + cRow * N, c_frag, N, wmma::mem_row_major);
}
} void hgemmWmmaNaive(half *A, half *B, half *C, size_t M, size_t N, size_t K) {
dim3 block(128, 4);
dim3 grid((M - 1) / (WMMA_M * block.x / WARP_SIZE) + 1, (N - 1) / (WMMA_N * block.y) + 1); wmmaNaiveKernel<<<grid, block>>>(A, B, C, M, N, K);
}

2.3 区别

从上述两个naive kernel的代码来看调用CUDA Core和Tensor Core的区别如下:
  • 计算层级:CUDA Core是线程级别,Tensor Core是warp级别
  • 计算维度:CUDA Core是一维逐点计算,Tensor Core是二维逐tile计算
  • 计算依赖:WMMA调用Tensor Core需要借助数据存储类fragment,CUDA Core不需要借助其他

3 底层代码

我们再对上述WMMA naive kernel做进一步探索,看一下它在RTX A6000(sm_86,CUDA 11.3)上对应的PTX和SASS。

3.1 PTX

dump出对应的PTX代码如下,好像不那么简单了。
.visible .entry _Z15wmmaNaiveKernelPK6__halfS1_PS_mmm(
.param .u64 _Z15wmmaNaiveKernelPK6__halfS1_PS_mmm_param_0,
.param .u64 _Z15wmmaNaiveKernelPK6__halfS1_PS_mmm_param_1,
.param .u64 _Z15wmmaNaiveKernelPK6__halfS1_PS_mmm_param_2,
.param .u64 _Z15wmmaNaiveKernelPK6__halfS1_PS_mmm_param_3,
.param .u64 _Z15wmmaNaiveKernelPK6__halfS1_PS_mmm_param_4,
.param .u64 _Z15wmmaNaiveKernelPK6__halfS1_PS_mmm_param_5
)
{
.reg .pred %p<8>;
.reg .b16 %rs<2>;
.reg .f32 %f<2>;
.reg .b32 %r<58>;
.reg .b64 %rd<28>; ld.param.u64 %rd9, [_Z15wmmaNaiveKernelPK6__halfS1_PS_mmm_param_0];
ld.param.u64 %rd10, [_Z15wmmaNaiveKernelPK6__halfS1_PS_mmm_param_1];
ld.param.u64 %rd11, [_Z15wmmaNaiveKernelPK6__halfS1_PS_mmm_param_2];
ld.param.u64 %rd14, [_Z15wmmaNaiveKernelPK6__halfS1_PS_mmm_param_3];
ld.param.u64 %rd12, [_Z15wmmaNaiveKernelPK6__halfS1_PS_mmm_param_4];
ld.param.u64 %rd13, [_Z15wmmaNaiveKernelPK6__halfS1_PS_mmm_param_5];
mov.u32 %r19, %ntid.x;
mov.u32 %r20, %ctaid.x;
mov.u32 %r21, %tid.x;
mad.lo.s32 %r22, %r20, %r19, %r21;
mov.u32 %r23, %ntid.y;
mov.u32 %r24, %ctaid.y;
mov.u32 %r25, %tid.y;
mad.lo.s32 %r26, %r24, %r23, %r25;
mov.f32 %f1, 0f00000000; { cvt.rn.f16.f32 %rs1, %f1;} mov.b32 %r50, {%rs1, %rs1};
mul.wide.u32 %rd1, %r26, 16;
shr.u32 %r27, %r22, 1;
and.b32 %r28, %r27, 2147483632;
cvt.u64.u32 %rd2, %r28;
setp.lt.u64 %p2, %rd2, %rd14;
setp.lt.u64 %p3, %rd1, %rd12;
and.pred %p1, %p2, %p3;
setp.eq.s64 %p4, %rd13, 0;
mov.u32 %r51, %r50;
mov.u32 %r52, %r50;
mov.u32 %r53, %r50;
@%p4 bra $L__BB0_5; mul.lo.s64 %rd3, %rd2, %rd13;
cvt.u32.u64 %r2, %rd13;
mul.lo.s64 %rd4, %rd1, %rd13;
cvta.to.global.u64 %rd5, %rd10;
cvta.to.global.u64 %rd6, %rd9;
mov.u64 %rd27, 0;
not.pred %p5, %p1;
mov.u32 %r51, %r50;
mov.u32 %r52, %r50;
mov.u32 %r53, %r50; $L__BB0_2:
@%p5 bra $L__BB0_4; add.s64 %rd16, %rd27, %rd3;
shl.b64 %rd17, %rd16, 1;
add.s64 %rd18, %rd6, %rd17;
wmma.load.a.sync.aligned.row.m16n16k16.global.f16 {%r29, %r30, %r31, %r32, %r33, %r34, %r35, %r36}, [%rd18], %r2;
add.s64 %rd19, %rd27, %rd4;
shl.b64 %rd20, %rd19, 1;
add.s64 %rd21, %rd5, %rd20;
wmma.load.b.sync.aligned.col.m16n16k16.global.f16 {%r37, %r38, %r39, %r40, %r41, %r42, %r43, %r44}, [%rd21], %r2;
wmma.mma.sync.aligned.row.col.m16n16k16.f16.f16 {%r53, %r52, %r51, %r50}, {%r29, %r30, %r31, %r32, %r33, %r34, %r35, %r36}, {%r37, %r38, %r39, %r40, %r41, %r42, %r43, %r44}, {%r53, %r52, %r51, %r50}; $L__BB0_4:
add.s64 %rd27, %rd27, 16;
setp.lt.u64 %p6, %rd27, %rd13;
@%p6 bra $L__BB0_2; $L__BB0_5:
not.pred %p7, %p1;
@%p7 bra $L__BB0_7; mul.lo.s64 %rd22, %rd2, %rd12;
add.s64 %rd23, %rd22, %rd1;
cvta.to.global.u64 %rd24, %rd11;
shl.b64 %rd25, %rd23, 1;
add.s64 %rd26, %rd24, %rd25;
cvt.u32.u64 %r45, %rd12;
wmma.store.d.sync.aligned.row.m16n16k16.global.f16 [%rd26], {%r53, %r52, %r51, %r50}, %r45; $L__BB0_7:
ret; }

不过我们主要关注WMMA相关的PTX指令,如下所示。可以看到这里正是Nvidia提供的WMMA PTX指令来调用Tensor Core,所以无论是使用WMMA API编程,还是使用WMMA PTX指令编程,底层差别不会太大。

wmma.load.a.sync.aligned.row.m16n16k16.global.f16
wmma.load.b.sync.aligned.col.m16n16k16.global.f16
wmma.mma.sync.aligned.row.col.m16n16k16.f16.f16
wmma.store.d.sync.aligned.row.m16n16k16.global.f16

3.2 SASS

进一步dump出对应的SASS代码,似乎也不简单。
      IMAD.MOV.U32 R1, RZ, RZ, c[0x0][0x28]
S2R R0, SR_CTAID.X
ISETP.NE.U32.AND P2, PT, RZ, c[0x0][0x188], PT
ULDC.64 UR4, c[0x0][0x118]
CS2R R8, SRZ
S2R R10, SR_CTAID.Y
ISETP.NE.AND.EX P2, PT, RZ, c[0x0][0x18c], PT, P2
S2R R5, SR_TID.Y
S2R R3, SR_TID.X
IMAD R10, R10, c[0x0][0x4], R5
IMAD R0, R0, c[0x0][0x0], R3
IMAD.WIDE.U32 R10, R10, 0x10, RZ
CS2R R2, SRZ
SHF.R.U32.HI R0, RZ, 0x1, R0
ISETP.GE.U32.AND P0, PT, R10, c[0x0][0x180], PT
LOP3.LUT R13, R0, 0x7ffffff0, RZ, 0xc0, !PT
ISETP.GE.U32.AND.EX P0, PT, R11, c[0x0][0x184], PT, P0
ISETP.LT.U32.AND P1, PT, R13, c[0x0][0x178], PT
ISETP.LT.U32.AND.EX P0, PT, RZ, c[0x0][0x17c], !P0, P1
@!P2 BRA 0x7f1eaefc0160
BSSY B0, 0x7f1eaefc0160
IMAD.MOV.U32 R0, RZ, RZ, RZ
CS2R R8, SRZ
IMAD.MOV.U32 R15, RZ, RZ, RZ
IMAD.MOV.U32 R2, RZ, RZ, RZ
BSSY B1, 0x7f1eaefc0100
@!P0 BRA 0x7f1eaefc00f0
S2R R16, SR_LANEID
IMAD R17, R11, c[0x0][0x188], RZ
IMAD.MOV.U32 R14, RZ, RZ, R0
IMAD.MOV.U32 R23, RZ, RZ, c[0x0][0x188]
IMAD.WIDE.U32 R6, R10, c[0x0][0x188], R14
SHF.R.U32.HI R12, RZ, 0x1, R23
IMAD R17, R10, c[0x0][0x18c], R17
LEA R21, P2, R6, c[0x0][0x168], 0x1
IMAD.WIDE.U32 R4, R13, c[0x0][0x188], R14
IMAD.IADD R7, R7, 0x1, R17
IMAD.MOV.U32 R17, RZ, RZ, RZ
IMAD R5, R13, c[0x0][0x18c], R5
LEA.HI.X R7, R6, c[0x0][0x16c], R7, 0x1, P2
SHF.R.U32.HI R19, RZ, 0x2, R16
LOP3.LUT R16, R16, 0x3, RZ, 0xc0, !PT
IMAD.WIDE.U32 R16, R19, R12, R16
LEA R19, P1, R4, c[0x0][0x160], 0x1
LEA.HI.X R5, R4, c[0x0][0x164], R5, 0x1, P1
LEA R18, P1, R16, R19, 0x2
LEA R20, P2, R16, R21, 0x2
LEA.HI.X R19, R16, R5, R17, 0x2, P1
LEA.HI.X R21, R16, R7, R17, 0x2, P2
IMAD.WIDE.U32 R16, R23, 0x10, R18
LDG.E R4, [R18.64]
IMAD.WIDE.U32 R22, R23, 0x10, R20
LDG.E R24, [R20.64]
LDG.E R25, [R20.64+0x10]
LDG.E R6, [R18.64+0x10]
LDG.E R5, [R16.64]
LDG.E R7, [R16.64+0x10]
LDG.E R26, [R22.64]
LDG.E R27, [R22.64+0x10]
WARPSYNC 0xffffffff
HMMA.16816.F16 R8, R4, R24, R8
HMMA.16816.F16 R2, R4, R26, R2
NOP
BSYNC B1
IADD3 R0, P1, R0, 0x10, RZ
IMAD.X R15, RZ, RZ, R15, P1
ISETP.GE.U32.AND P1, PT, R0, c[0x0][0x188], PT
ISETP.GE.U32.AND.EX P1, PT, R15, c[0x0][0x18c], PT, P1
@!P1 BRA 0x7f1eaefbfe90
BSYNC B0
@!P0 EXIT
S2R R4, SR_LANEID
IMAD.MOV.U32 R15, RZ, RZ, c[0x0][0x180]
WARPSYNC 0xffffffff
IMAD.WIDE.U32 R10, R13, c[0x0][0x180], R10
SHF.R.U32.HI R15, RZ, 0x1, R15
IMAD.MOV.U32 R5, RZ, RZ, RZ
LEA R7, P0, R10, c[0x0][0x170], 0x1
IMAD R11, R13, c[0x0][0x184], R11
LEA.HI.X R11, R10, c[0x0][0x174], R11, 0x1, P0
SHF.R.U32.HI R0, RZ, 0x2, R4
LOP3.LUT R4, R4, 0x3, RZ, 0xc0, !PT
IMAD.WIDE.U32 R4, R0, R15, R4
LEA R6, P0, R4, R7, 0x2
LEA.HI.X R7, R4, R11, R5, 0x2, P0
IMAD.WIDE.U32 R4, R15, 0x20, R6
STG.E [R6.64], R8
STG.E [R4.64], R9
STG.E [R6.64+0x10], R2
STG.E [R4.64+0x10], R3
EXIT
BRA 0x7f1eaefc02b0
NOP
NOP
NOP
NOP
NOP
NOP
NOP
NOP
NOP
NOP
NOP
NOP

我们依然主要关注WMMA相关的SASS指令,如下所示。可以发现WMMA161616在底层是通过两个HMMA16816指令实现,同样地,SASS指令也是Nvidia提供的另一种调用Tensor Core的编程方法。

HMMA.16816.F16
Nvidia Tensor Core初探中提到Nvidia提供了四种调用Tensor Core的编程方法,这里提到了三种,还有一种是MMA PTX指令,其中MMA16816 PTX指令底层实现即是HMMA16816指令,后续会在MMA PTX相关文章中提及。

4 其他

4.1 HGEMM优化

学习WMMA API的目标在于调用Tensor Core优化HGEMM,相比于cublas,WMMA的性能究竟如何?
 
 

Nvidia Tensor Core-WMMA API编程入门的更多相关文章

  1. Mysql C语言API编程入门讲解

    原文:Mysql C语言API编程入门讲解 软件开发中我们经常要访问数据库,存取数据,之前已经有网友提出让鸡啄米讲讲数据库编程的知识,本文就详细讲解如何使用Mysql的C语言API进行数据库编程.   ...

  2. Windows API 编程入门

    Windows 工作原理的中心思想就是“动态链接”概念.Windows 自身带有一大套函数,应用程序就是通过调用这些函数 来实现它的用户界面和在屏幕上显示文本和图形的.这些函数都是在动态链接库里实现的 ...

  3. NVIDIA Tensor Cores解析

    NVIDIA Tensor Cores解析 高性能计算机和人工智能前所未有的加速 Tensor Cores支持混合精度计算,动态调整计算以加快吞吐量,同时保持精度.最新一代将这些加速功能扩展到各种工作 ...

  4. NVIDIA深度学习Tensor Core性能解析(下)

    NVIDIA深度学习Tensor Core性能解析(下) DeepBench推理测试之RNN和Sparse GEMM DeepBench的最后一项推理测试是RNN和Sparse GEMM,虽然测试中可 ...

  5. NVIDIA深度学习Tensor Core性能解析(上)

    NVIDIA深度学习Tensor Core性能解析(上) 本篇将通过多项测试来考验Volta架构,利用各种深度学习框架来了解Tensor Core的性能. 很多时候,深度学习这样的新领域会让人难以理解 ...

  6. 转载自~浮云比翼:Step by Step:Linux C多线程编程入门(基本API及多线程的同步与互斥)

    Step by Step:Linux C多线程编程入门(基本API及多线程的同步与互斥)   介绍:什么是线程,线程的优点是什么 线程在Unix系统下,通常被称为轻量级的进程,线程虽然不是进程,但却可 ...

  7. 《ASP.NET Core跨平台开发从入门到实战》Web API自定义格式化protobuf

    <ASP.NET Core跨平台开发从入门到实战>样章节 Web API自定义格式化protobuf. 样章 Protocol Buffers 是一种轻便高效的结构化数据存储格式,可以用于 ...

  8. 初识Django —Python API接口编程入门

    初识Django —Python API接口编程入门 一.WEB架构的简单介绍 Django是什么? Django是一个开放源代码的Web应用框架,由Python写成.我们的目标是用Python语言, ...

  9. Storm编程入门API系列之Storm的Topology多个Workers数目控制实现

    前期博客 Storm编程入门API系列之Storm的Topology默认Workers.默认executors和默认tasks数目 继续编写 StormTopologyMoreWorker.java ...

  10. Storm编程入门API系列之Storm的Topology多个Executors数目控制实现

    前期博客 Storm编程入门API系列之Storm的Topology默认Workers.默认executors和默认tasks数目 Storm编程入门API系列之Storm的Topology多个Wor ...

随机推荐

  1. 设置mode='out-on'导致路由切换过快路由加载报错 Failed to execute 'insertBefore' on 'Node'

    原代码: 解决代码: 原因未知

  2. 高并发解决方案之 mysql悲观锁:select ... for update

    select ... for update 场景:多个进程都先读后写咋办,需要的是让他们串行执行. 比如库存的减少.一般这些操作都是很长一串并且是开启事务的.如果库存刚开始读的时候是1,而立马另一个进 ...

  3. SQL Server 手工 锁表、查询被锁表、解锁相关语句

    SQL Server 手工 锁表.查询被锁表.解锁相关语句 --锁表(其它事务不能读.更新.删除) BEGIN TRAN SELECT * FROM <表名> WITH(TABLOCKX) ...

  4. python Queue(队列学习)

    Python 的Queue模块中提供了同步的.线程安全的队列类,包括FIFO(先入先出)队列Queue,LIFO(后入先出)队列LifoQueue,和优先级队列PriorityQueue.这些队列都实 ...

  5. 循环文件夹汇总所有发票开具Excel文件数据

    'xlsx cnADO.Open "provider=Microsoft.ACE.OLEDB.12.0;extended properties='excel 8.0;hdr=no;imex= ...

  6. Docker安装一些软件

    1.Docker开始远程访问 vim /lib/systemd/system/docker.service 在ExecStart的值最最后面追加:空格+-H tcp://0.0.0.0:2375 sy ...

  7. 搭建Angular基础项目学习

    https://stackblitz.com/借助StackBlitz网站可快速开始搭建一个angular项目 一个angular的component包含三项东西 A component class  ...

  8. python语言linux操作系统oracle环境安装

    金句:如果没把握,最好先Google一下. 1.严格按照 https://oracle.github.io/odpi/doc/installation.html#linux 教程一步步做 包括下载的软 ...

  9. error check

    #define SYSTEM_PRTCT_NOERR 0 #define SYSTEM_PRTCT_COVER (1 << 0) /* */#define SYSTEM_PRTCT_LPH ...

  10. MATLAB默认路径修改

    笔者曾尝试在软件界面的"设置路径"或者Parallel中修改默认路径,但多次尝试均失败.后来经人提点,MATLAB默认文件夹路径可以在桌面图标属性中"起始位置" ...