This extension provides features for
- mapping between memory and fragment (primitive functions)
- operations for vectors
- loading a vector as a fragment
- storing a fragment as a vector
- making eye matrix fragment
- C++ interface for `mma` instructions
- Error Correction (TCEC) for SGEMM emulation [detail]
- arithmetic operators for fragments (`+, -, *, /, fma`) [detail]
- utils [detail]
- etc
All of these features work without using extra shared memory.
Caution!!
The WMMA API (including the layout of fragment elements that this library relies on) is not backward compatible across GPU architectures.
Please compile for the virtual architecture that matches the real GPU you run on when you use this library.
For instance, a program compiled with `-arch=sm_70` may not work correctly on Ampere GPUs.
Requirements:
- CUDA (10.2 or later)
- C++ (17 or later)

Supported fragments per architecture ((m, n, k) shape, type):
- sm_70: ((16, 16, 16), fp16/fp32)
- sm_75: ((16, 16, 16), fp16/fp32)
- sm_80: ((16, 16, 16), fp16/fp32), ((16, 16, 8), tf32)
`mtk::wmma::foreach` calculates the mapping between memory and fragment elements.
// Fill frag_b from a 16x16 matrix held in shared memory, using the
// memory-index -> fragment-element mapping provided by foreach.
nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, 16, 16, 16, half, nvcuda::wmma::col_major> frag_b;
__shared__ compute_t matrix[16 * 16];
mtk::wmma::foreach<decltype(frag_b)>(
	[&](const unsigned* frag_index_list, const unsigned fragment_index_count, const unsigned mem_index) {
		// Split mem_index into (row, col) with the row varying fastest,
		// i.e. the matrix is addressed column-major.
		const auto row = mem_index % 16;
		const auto col = mem_index / 16;
		// One memory element may be held by several fragment elements.
		for (unsigned k = 0; k < fragment_index_count; ++k) {
			frag_b.x[frag_index_list[k]] = convert_to<half>(matrix[col * 16 + row]);
		}
	});
`mtk::wmma::foreach_ij` calculates the mapping between the matrix element position (i, j) and fragment elements.
// Fill frag_b from a 16x16 matrix, visiting each matrix element by its (i, j) position.
nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, 16, 16, 16, half, nvcuda::wmma::col_major> frag_b;
__shared__ compute_t matrix[16 * 16];
mtk::wmma::foreach_ij<decltype(frag_b)>(
	[&](const unsigned* frag_index_list, const unsigned fragment_index_count, const unsigned i, const unsigned j) {
		// The matrix is addressed column-major: element (i, j) is at i + 16 * j.
		for (unsigned k = 0; k < fragment_index_count; ++k) {
			const auto frag_index = frag_index_list[k];
			frag_b.x[frag_index] = convert_to<half>(matrix[i + 16 * j]);
		}
	});
`mtk::wmma::foreach_v` calculates the mapping between a given vector and fragment elements.
// Load a 16-element vector from shared memory into frag_b.
nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, 16, 16, 16, half, nvcuda::wmma::col_major> frag_b;
__shared__ compute_t vector[16];
mtk::wmma::foreach_v<decltype(frag_b)>(
	[&](const unsigned* frag_index_list, const unsigned fragment_index_count, const unsigned mem_index) {
		// Copy vector[mem_index] into every fragment element that maps to it.
		for (unsigned k = 0; k < fragment_index_count; ++k) {
			frag_b.x[frag_index_list[k]] = convert_to<half>(vector[mem_index]);
		}
	});
// This is equivalent to `load_vector`.
// Store the vector held by accumulator frag_c into a 16-element shared vector.
nvcuda::wmma::fragment<nvcuda::wmma::accumulator, 16, 16, 16, float> frag_c;
__shared__ compute_t vector[16];
// The accumulator fragment type carries no layout, so the memory layout is passed explicitly.
mtk::wmma::foreach_v<decltype(frag_c)>(nvcuda::wmma::mem_col_major,
	[&](const unsigned* frag_index_list, const unsigned fragment_index_count, const unsigned mem_index) {
		for (unsigned k = 0; k < fragment_index_count; ++k) {
			vector[mem_index] = convert_to<compute_t>(frag_c.x[frag_index_list[k]]);
		}
	});
// This is equivalent to `store_vector`.
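`mtk::wmma::map` returns the mapping between matrix element (i, j) and fragment element (tid, fid), where tid is the lane (thread index within the warp) and fid is the index into the fragment's `x` array.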
// Set matrix element (i, j) of frag_b to 3 (i and j are supplied by the surrounding code).
nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, 16, 16, 16, half, nvcuda::wmma::col_major> frag_b;
unsigned tid_list[2];
unsigned fid_list[2];
unsigned list_size;
mtk::wmma::map<decltype(frag_b)>(tid_list, fid_list, list_size, i, j);
const unsigned lane_id = threadIdx.x & 0x1f;
for (unsigned k = 0; k < list_size; k++) {
	// Each (tid, fid) pair names one lane's copy of the element; other lanes skip it.
	if (tid_list[k] != lane_id) {
		continue;
	}
	frag_b.x[fid_list[k]] = 3.0f;
}
#include <mma.h>
#include <wmma_extension/wmma_extension.hpp>
// Example: load a 16-element shared-memory vector into matrix_a and matrix_b
// fragments, compute frag_c = frag_a * frag_b (+ zero) with WMMA, and write frag_c
// back to the same vector with store_vector.
// Uses 16x16x16 fp16/fp32 fragments (sm_70 or later).
// NOTE(review): written for a single-warp launch (e.g. kernel<<<1, 32>>>()); with more
// warps, every warp stores to the same shared vector.
__global__ void kernel() {
	nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, 16, 16, 16, half, nvcuda::wmma::col_major> frag_a;
	nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, 16, 16, 16, half, nvcuda::wmma::col_major> frag_b;
	nvcuda::wmma::fragment<nvcuda::wmma::accumulator, 16, 16, 16, float> frag_c;

	__shared__ float vec16[16];

	// __shared__ arrays start uninitialized: write the input vector before any
	// fragment reads it, then make the writes visible to the whole block.
	for (unsigned i = threadIdx.x; i < 16; i += blockDim.x) {
		vec16[i] = static_cast<float>(i);
	}
	__syncthreads();

	mtk::wmma::load_vector(frag_a, vec16);
	mtk::wmma::load_vector(frag_b, vec16);

	nvcuda::wmma::fill_fragment(frag_c, 0.0f);
	nvcuda::wmma::mma_sync(frag_c, frag_a, frag_b, frag_c);

	// vec16 is both input and output: ensure all threads finished reading it
	// before it is overwritten.
	__syncthreads();
	mtk::wmma::store_vector(vec16, frag_c, nvcuda::wmma::mem_col_major);
}
- Arguments
  - dst_fragment : Destination fragment (`accumulator`)
  - alpha : diagonal element
- dst_fragment : Destination fragment (`accumulator`)
- Argument
- dst_fragment : Destination fragment
This function outputs the elements of a fragment.
- Arguments
  - frag : Target fragment
  - name : name printed for the fragment (`char*`, optional)
#include <wmma_extension/wmma_mma.hpp>
// Example: compute d = a * b + c for a single m16n8k16 tile with mtk::wmma::mma
// fragments (half inputs, float accumulator; requires sm_80 or later).
// All matrices are column-major, so each leading dimension is the number of rows:
//   a : M x K = 16 x 16 half,  ldm = 16
//   b : K x N = 16 x  8 half,  ldm = 16
//   c : M x N = 16 x  8 float, ldm = 16
//   d : M x N = 16 x  8 float, ldm = 16 (output)
// NOTE(review): intended for a single-warp launch; every warp of a larger block
// would compute and store the same tile.
__global__ void kernel(float* const d, const half* const a, const half* const b, const float* const c) {
	mtk::wmma::mma::fragment<nvcuda::wmma::matrix_a   , 16, 8, 16, half, nvcuda::wmma::col_major> frag_a;
	mtk::wmma::mma::fragment<nvcuda::wmma::matrix_b   , 16, 8, 16, half, nvcuda::wmma::col_major> frag_b;
	mtk::wmma::mma::fragment<nvcuda::wmma::accumulator, 16, 8, 16, float> frag_c;
	mtk::wmma::mma::fragment<nvcuda::wmma::accumulator, 16, 8, 16, float> frag_d;

	mtk::wmma::mma::load_matrix_sync(frag_a, a, 16);
	// b is column-major K x N, so its leading dimension (column stride) is K = 16, not N = 8;
	// a stride of 8 would make consecutive columns overlap.
	mtk::wmma::mma::load_matrix_sync(frag_b, b, 16);
	mtk::wmma::mma::load_matrix_sync(frag_c, c, 16, nvcuda::wmma::mem_col_major);

	mtk::wmma::mma::mma_sync(frag_d, frag_a, frag_b, frag_c);

	mtk::wmma::mma::store_matrix_sync(d, frag_d, 16, nvcuda::wmma::mem_col_major);
}
| shape | type | arch |
|---|---|---|
| m16n8k16 | `half` | sm_80 or higher |
| m16n8k8 | `half` | sm_75 or higher |
| m16n8k8 | `nvcuda::wmma::tf32` | sm_80 or higher |
| m8n8k4 | `half` | sm_70, sm_75 |
- `foreach`
- `foreach_v`
- `load_matrix_sync`
- `store_matrix_sync`
- `fill_fragment`
- `fill_zero`
License: MIT