Custom CUDA code utilizing tensor cores

I want custom a cuda matrix multiplication using tensor cores in PyTorch. But it doesn’t work when compling the operator. The source code was refered to the sample code provided by NVIDIA which act normally on my machine.

The error message is mainly

FAILED: /home/hemeng/Desktop/extension-cpp-master/cuda/build/temp.linux-x86_64-3.9/matmul_cuda_kernel.o 
/usr/local/cuda-11.3/bin/nvcc  -I/home/hemeng/.local/lib/python3.9/site-packages/torch/include -I/home/hemeng/.local/lib/python3.9/site-packages/torch/include/torch/csrc/api/include -I/home/hemeng/.local/lib/python3.9/site-packages/torch/include/TH -I/home/hemeng/.local/lib/python3.9/site-packages/torch/include/THC -I/usr/local/cuda-11.3/include -I/usr/include/python3.9 -c -c /home/hemeng/Desktop/extension-cpp-master/cuda/matmul_cuda_kernel.cu -o /home/hemeng/Desktop/extension-cpp-master/cuda/build/temp.linux-x86_64-3.9/matmul_cuda_kernel.o -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr --compiler-options ''"'"'-fPIC'"'"'' -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1011"' -DTORCH_EXTENSION_NAME=matmul_cuda -D_GLIBCXX_USE_CXX11_ABI=0 -gencode=arch=compute_86,code=compute_86 -gencode=arch=compute_86,code=sm_86 -std=c++14
/home/hemeng/Desktop/extension-cpp-master/cuda/matmul_cuda_kernel.cu(127): error: incomplete type is not allowed

/home/hemeng/Desktop/extension-cpp-master/cuda/matmul_cuda_kernel.cu(128): error: incomplete type is not allowed

2 errors detected in the compilation of "/home/hemeng/Desktop/extension-cpp-master/cuda/matmul_cuda_kernel.cu".

and the kernel implementation is mainly (haven’t use scalar_t in order to locate the problem)

using namespace nvcuda;
// template <typename scalar_t>
__global__ void simple_wmma_gemm(float *a, float *b, float *c, int m_ld,
                                 int n_ld, int k_ld, int num_bh)
{
  // Leading dimensions. Packed with no transpositions.
  int lda = k_ld;
  int ldb = k_ld;
  int ldc = n_ld;

  // Tile using a 2D grid
  int warpM = (blockIdx.x * blockDim.x + threadIdx.x) / warpSize;
  int warpN = (blockIdx.y * blockDim.y + threadIdx.y);

  // Declare the fragments
  wmma::fragment<wmma::matrix_a, WMMA_M, WMMA_N, WMMA_K, float, wmma::row_major> a_frag;
  wmma::fragment<wmma::matrix_b, WMMA_M, WMMA_N, WMMA_K, float, wmma::col_major> b_frag;
  wmma::fragment<wmma::accumulator, WMMA_M, WMMA_N, WMMA_K, float> c_frag;

  wmma::fill_fragment(c_frag, 0.0f);

  // Loop over k
  for (int i = 0; i < k_ld; i += WMMA_K)
  {
    int aCol = i;
    int aRow = warpM * WMMA_M;
    int bCol = warpN * N;
    int bRow = i;

    // Bounds checking
    if (aRow < m_ld && aCol < k_ld && bRow < k_ld && bCol < n_ld)
    {
      // Load the inputs
      wmma::load_matrix_sync(a_frag, a + aCol + aRow * lda, lda);
      wmma::load_matrix_sync(b_frag, b + bRow + bCol * ldb, ldb);

      // Perform the matrix multiplication
      wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
    }
  }

  // Load in the current value of c, scale it by beta, and add this our result
  // scaled by alpha
  int cCol = warpN * WMMA_N;
  int cRow = warpM * WMMA_M;

  if (cRow < m_ld && cCol < n_ld)
  {
    // Store the output
    wmma::store_matrix_sync(c + cCol + cRow * ldc, c_frag, ldc,
                            wmma::mem_row_major);
  }
}

May be the problem is the compiler could not recognize the “wmma::fragment<wmma::matrix_a, WMMA_M, WMMA_N, WMMA_K, float, wmma::row_major>” type ?

The full code is here

#include <torch/extension.h>

#include <cuda.h>
#include <cuda_runtime.h>

#include <mma.h>

#include <vector>

#define TILE_DIM 64

#define WMMA_M 16
#define WMMA_N 16
#define WMMA_K 16

#define M 16
#define N 16
#define K 16

// namespace
// {
template <typename scalar_t>
__global__ void matmul_cuda_1d_kernel(
    torch::PackedTensorAccessor<scalar_t, 1, torch::RestrictPtrTraits, size_t> mat_a,
    torch::PackedTensorAccessor<scalar_t, 1, torch::RestrictPtrTraits, size_t> mat_b,
    torch::PackedTensorAccessor<scalar_t, 1, torch::RestrictPtrTraits, size_t> mat_c,
    int m)
{
  int index = blockIdx.x * blockDim.x + threadIdx.x;
  if (index < m)
  {
    mat_c[index] = mat_a[index] * mat_b[index];
  }
}

template <typename scalar_t>
__global__ void matmul_cuda_2d_kernel(
    torch::PackedTensorAccessor<scalar_t, 2, torch::RestrictPtrTraits, size_t> mat_a,
    torch::PackedTensorAccessor<scalar_t, 2, torch::RestrictPtrTraits, size_t> mat_b,
    torch::PackedTensorAccessor<scalar_t, 2, torch::RestrictPtrTraits, size_t> mat_c,
    int m, int n, int k)
{
  int row = blockIdx.y * blockDim.y + threadIdx.y;
  int col = blockIdx.x * blockDim.x + threadIdx.x;
  if (row < m && col < k)
  {
    float Pvalue = 0;
    for (int i = 0; i < n; i++)
    {
      Pvalue += mat_a[row][i] * mat_b[i][col];
    }
    mat_c[row][col] = Pvalue;
  }
}

template <typename scalar_t>
__global__ void matmul_cuda_4d_kernel(
    torch::PackedTensorAccessor<scalar_t, 4, torch::RestrictPtrTraits, size_t> mat_a,
    torch::PackedTensorAccessor<scalar_t, 4, torch::RestrictPtrTraits, size_t> mat_b,
    torch::PackedTensorAccessor<scalar_t, 4, torch::RestrictPtrTraits, size_t> mat_c,
    int m, int n, int k)
{
  int row = blockIdx.y * blockDim.y + threadIdx.y;
  int col = blockIdx.x * blockDim.x + threadIdx.x;
  for (size_t batch = 0; batch < mat_a.size(0); batch++)
  {
    for (size_t head = 0; head < mat_a.size(1); head++)
    {
      if (row < m && col < k)
      {
        float Pvalue = 0;
        for (int i = 0; i < n; i++)
        {
          Pvalue += mat_a[batch][head][row][i] * mat_b[batch][head][i][col];
        }
        mat_c[batch][head][row][col] = Pvalue;
      }
    }
  }
}

template <typename scalar_t>
__global__ void shared_matmul_cuda_4d_kernel(
    torch::PackedTensorAccessor<scalar_t, 4, torch::RestrictPtrTraits, size_t> mat_a,
    torch::PackedTensorAccessor<scalar_t, 4, torch::RestrictPtrTraits, size_t> mat_b,
    torch::PackedTensorAccessor<scalar_t, 4, torch::RestrictPtrTraits, size_t> mat_c,
    int m, int n, int k)
{
  int row = blockIdx.y * blockDim.y + threadIdx.y;
  int col = blockIdx.x * blockDim.x + threadIdx.x;
  __shared__ float aTile[TILE_DIM][TILE_DIM];
  __shared__ float bTile[TILE_DIM][TILE_DIM];
  for (size_t batch = 0; batch < mat_a.size(0); batch++)
  {
    for (size_t head = 0; head < mat_a.size(1); head++)
    {
      if (row < m && col < k)
      {
        aTile[threadIdx.y][threadIdx.x] = mat_a[batch][head][row][threadIdx.x];
        bTile[threadIdx.y][threadIdx.x] = mat_b[batch][head][threadIdx.y][col];
        __syncwarp();
        float Pvalue = 0;
        for (int i = 0; i < n; i++)
        {
          Pvalue += aTile[threadIdx.y][i] * bTile[i][threadIdx.x];
        }
        mat_c[batch][head][row][col] = Pvalue;
      }
    }
  }
}
using namespace nvcuda;
// template <typename scalar_t>
__global__ void simple_wmma_gemm(float *a, float *b, float *c, int m_ld,
                                 int n_ld, int k_ld, int num_bh)
{
  // Leading dimensions. Packed with no transpositions.
  int lda = k_ld;
  int ldb = k_ld;
  int ldc = n_ld;

  // Tile using a 2D grid
  int warpM = (blockIdx.x * blockDim.x + threadIdx.x) / warpSize;
  int warpN = (blockIdx.y * blockDim.y + threadIdx.y);

  // Declare the fragments
  wmma::fragment<wmma::matrix_a, WMMA_M, WMMA_N, WMMA_K, float, wmma::row_major> a_frag;
  wmma::fragment<wmma::matrix_b, WMMA_M, WMMA_N, WMMA_K, float, wmma::col_major> b_frag;
  wmma::fragment<wmma::accumulator, WMMA_M, WMMA_N, WMMA_K, float> c_frag;

  wmma::fill_fragment(c_frag, 0.0f);

  // Loop over k
  for (int i = 0; i < k_ld; i += WMMA_K)
  {
    int aCol = i;
    int aRow = warpM * WMMA_M;
    int bCol = warpN * N;
    int bRow = i;

    // Bounds checking
    if (aRow < m_ld && aCol < k_ld && bRow < k_ld && bCol < n_ld)
    {
      // Load the inputs
      wmma::load_matrix_sync(a_frag, a + aCol + aRow * lda, lda);
      wmma::load_matrix_sync(b_frag, b + bRow + bCol * ldb, ldb);

      // Perform the matrix multiplication
      wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
    }
  }

  // Load in the current value of c, scale it by beta, and add this our result
  // scaled by alpha
  int cCol = warpN * WMMA_N;
  int cRow = warpM * WMMA_M;

  if (cRow < m_ld && cCol < n_ld)
  {
    // Store the output
    wmma::store_matrix_sync(c + cCol + cRow * ldc, c_frag, ldc,
                            wmma::mem_row_major);
  }
}
// } // namespace

torch::Tensor matmul_cuda_1d(
    torch::Tensor mat_a,
    torch::Tensor mat_b)
{
  auto mat_c = torch::zeros_like(mat_a);
  int m = mat_a.size(0);

  int threads = 1024;
  int blocks = (m + threads - 1) / threads;

  AT_DISPATCH_FLOATING_TYPES(mat_a.type(), "matmul_cuda_1d", ([&]
                                                              { matmul_cuda_1d_kernel<scalar_t><<<blocks, threads>>>(
                                                                    mat_a.packed_accessor<scalar_t, 1, torch::RestrictPtrTraits, size_t>(),
                                                                    mat_b.packed_accessor<scalar_t, 1, torch::RestrictPtrTraits, size_t>(),
                                                                    mat_c.packed_accessor<scalar_t, 1, torch::RestrictPtrTraits, size_t>(),
                                                                    m); }));
  return mat_c;
}

torch::Tensor matmul_cuda_2d(
    torch::Tensor mat_a,
    torch::Tensor mat_b)
{
  int m = mat_a.size(0);
  int n = mat_a.size(1);
  int k = mat_b.size(1);
  auto mat_c = torch::zeros({m, k});

  int dimx = 32;
  int dimy = 32;
  dim3 block(dimx, dimy);
  dim3 grid((m + block.x - 1) / block.x, (k + block.y - 1) / block.y);

  AT_DISPATCH_FLOATING_TYPES(mat_a.type(), "matmul_cuda_2d", ([&]
                                                              { matmul_cuda_2d_kernel<scalar_t><<<grid, block>>>(
                                                                    mat_a.packed_accessor<scalar_t, 2, torch::RestrictPtrTraits, size_t>(),
                                                                    mat_b.packed_accessor<scalar_t, 2, torch::RestrictPtrTraits, size_t>(),
                                                                    mat_c.packed_accessor<scalar_t, 2, torch::RestrictPtrTraits, size_t>(),
                                                                    m, n, k); }));
  return mat_c;
}

torch::Tensor matmul_cuda_4d(
    torch::Tensor mat_a,
    torch::Tensor mat_b)
{
  // A: m*n    B: n*k
  int m = mat_a.size(2);
  int n = mat_a.size(3);
  int k = mat_b.size(3);
  int batchs = mat_a.size(0);
  int heads = mat_a.size(1);

  auto options = torch::TensorOptions().device(torch::kCUDA, 0);

  auto mat_c = torch::zeros({batchs, heads, m, k}, options);

  // auto ac_mat_a = mat_a.packed_accessor<scalar_t, 4, torch::RestrictPtrTraits, size_t>();
  // auto ac_mat_b = mat_b.packed_accessor<scalar_t, 4, torch::RestrictPtrTraits, size_t>();
  // auto ac_mat_c = mat_c.packed_accessor<scalar_t, 4, torch::RestrictPtrTraits, size_t>();

  int dimx = 32;
  int dimy = 32;
  dim3 block(dimx, dimy);
  dim3 grid((m + block.x - 1) / block.x, (k + block.y - 1) / block.y);

  AT_DISPATCH_FLOATING_TYPES(mat_a.type(), "matmul_cuda_4d", ([&]
                                                              { matmul_cuda_4d_kernel<scalar_t><<<grid, block>>>(
                                                                    mat_a.packed_accessor<scalar_t, 4, torch::RestrictPtrTraits, size_t>(),
                                                                    mat_b.packed_accessor<scalar_t, 4, torch::RestrictPtrTraits, size_t>(),
                                                                    mat_c.packed_accessor<scalar_t, 4, torch::RestrictPtrTraits, size_t>(),
                                                                    m, n, k); }));

  return mat_c;
}

torch::Tensor opt_matmul_cuda_4d(
    torch::Tensor mat_a,
    torch::Tensor mat_b)
{
  // A: m*n    B: n*k
  int m = mat_a.size(2);
  int n = mat_a.size(3);
  int k = mat_b.size(3);
  int batchs = mat_a.size(0);
  int heads = mat_a.size(1);

  auto options = torch::TensorOptions().device(torch::kCUDA, 0);

  auto mat_c = torch::zeros({batchs, heads, m, k}, options);

  // auto ac_mat_a = mat_a.packed_accessor<scalar_t, 4, torch::RestrictPtrTraits, size_t>();
  // auto ac_mat_b = mat_b.packed_accessor<scalar_t, 4, torch::RestrictPtrTraits, size_t>();
  // auto ac_mat_c = mat_c.packed_accessor<scalar_t, 4, torch::RestrictPtrTraits, size_t>();

  int dimx = 32;
  int dimy = 32;
  dim3 block(dimx, dimy);
  dim3 grid((m + block.x - 1) / block.x, (k + block.y - 1) / block.y);

  AT_DISPATCH_FLOATING_TYPES(mat_a.type(), "matmul_cuda_4d", ([&]
                                                              { shared_matmul_cuda_4d_kernel<scalar_t><<<grid, block>>>(
                                                                    mat_a.packed_accessor<scalar_t, 4, torch::RestrictPtrTraits, size_t>(),
                                                                    mat_b.packed_accessor<scalar_t, 4, torch::RestrictPtrTraits, size_t>(),
                                                                    mat_c.packed_accessor<scalar_t, 4, torch::RestrictPtrTraits, size_t>(),
                                                                    m, n, k); }));

  return mat_c;
}

torch::Tensor tensor_core_matmul_cuda_4d(
    torch::Tensor mat_a,
    torch::Tensor mat_b)
{
  // A: m*n    B: n*k
  int m = mat_a.size(2);
  int n = mat_a.size(3);
  int k = mat_b.size(3);
  int batchs = mat_a.size(0);
  int heads = mat_a.size(1);

  auto options = torch::TensorOptions().device(torch::kCUDA, 0);

  auto mat_c = torch::zeros({batchs, heads, m, k}, options);

  // auto ac_mat_a = mat_a.packed_accessor<scalar_t, 4, torch::RestrictPtrTraits, size_t>();
  // auto ac_mat_b = mat_b.packed_accessor<scalar_t, 4, torch::RestrictPtrTraits, size_t>();
  // auto ac_mat_c = mat_c.packed_accessor<scalar_t, 4, torch::RestrictPtrTraits, size_t>();

  int dimx = 32;
  int dimy = 32;
  dim3 block(dimx, dimy);
  dim3 grid((m + block.x - 1) / block.x, (k + block.y - 1) / block.y);

  // AT_DISPATCH_FLOATING_TYPES(mat_a.type(), "tensor_core_matmul_cuda_4d", ([&]
  //                                                                         { simple_wmma_gemm<float><<<grid, block>>>(
  //                                                                               mat_a.data<float>(),
  //                                                                               mat_b.data<float>(),
  //                                                                               mat_c.data<float>(),
  //                                                                               m, n, k, 0); }));
  simple_wmma_gemm<<<grid, block>>>(
      mat_a.data<float>(),
      mat_b.data<float>(),
      mat_c.data<float>(),
      m, n, k, 0);
  // AT_DISPATCH_FLOATING_TYPES(mat_a.type(), "tensor_core_matmul_cuda_4d", ([&]
  //                                                                         { simple_wmma_gemm<scalar_t><<<grid, block>>>(
  //                                                                               mat_a.data<scalar_t>(),
  //                                                                               mat_b.data<scalar_t>(),
  //                                                                               mat_c.data<scalar_t>(),
  //                                                                               m, n, k, 0); }));
  // AT_DISPATCH_FLOATING_TYPES(mat_a.type(), "matmul_cuda_4d", ([&]
  //                                                             { shared_matmul_cuda_4d_kernel<scalar_t><<<grid, block>>>(
  //                                                                   mat_a.packed_accessor<scalar_t, 4, torch::RestrictPtrTraits, size_t>(),
  //                                                                   mat_b.packed_accessor<scalar_t, 4, torch::RestrictPtrTraits, size_t>(),
  //                                                                   mat_c.packed_accessor<scalar_t, 4, torch::RestrictPtrTraits, size_t>(),
  //                                                                   m, n, k); }));

  return mat_c;
}

I just learned how to develop custom cuda operator in PyTorch and it’s just SO FRUSTRATED to stuck in the BUG.

Any help will be appreciated!!!

OH MY GOSH!!! I solve it.

The problem is still my inexperience.

The definition of wmma::fragment is

  // 
  // Fragment template
  // 
  template<typename Use, int m, int n, int k, typename T, typename Layout=void> class fragment;

  // 
  // Fragments for 16x16x16
  // 
  template<> class fragment<matrix_a, 16, 16, 16, __half, row_major> : public __frag_base<__half, 16> {};
  template<> class fragment<matrix_a, 16, 16, 16, __half, col_major> : public __frag_base<__half, 16> {};
  template<> class fragment<matrix_b, 16, 16, 16, __half, row_major> : public __frag_base<__half, 16> {};
  template<> class fragment<matrix_b, 16, 16, 16, __half, col_major> : public __frag_base<__half, 16> {};
  template<> class fragment<accumulator, 16, 16, 16, __half> : public __frag_base<__half, 8> {};
  template<> class fragment<accumulator, 16, 16, 16, float> : public __frag_base<float, 8> {};

#ifdef __CUDA_IMMA__
  template<> class fragment<matrix_a, 16, 16, 16, signed char, row_major> : public __frag_base<signed char, 8> {};
  template<> class fragment<matrix_a, 16, 16, 16, signed char, col_major> : public __frag_base<signed char, 8> {};
  template<> class fragment<matrix_a, 16, 16, 16, unsigned char, row_major> : public __frag_base<unsigned char, 8> {};
  template<> class fragment<matrix_a, 16, 16, 16, unsigned char, col_major> : public __frag_base<unsigned char, 8> {};
  template<> class fragment<matrix_b, 16, 16, 16, signed char, row_major> : public __frag_base<signed char, 8> {};
  template<> class fragment<matrix_b, 16, 16, 16, signed char, col_major> : public __frag_base<signed char, 8> {};  
  template<> class fragment<matrix_b, 16, 16, 16, unsigned char, row_major> : public __frag_base<unsigned char, 8> {};
  template<> class fragment<matrix_b, 16, 16, 16, unsigned char, col_major> : public __frag_base<unsigned char, 8> {};  
  template<> class fragment<accumulator, 16, 16, 16, int> : public __frag_base<int, 8> {};
#endif  /* __CUDA_IMMA__ */

#ifdef __CUDA_AMPERE_MMA__
  template<> class fragment<matrix_a, 16, 16, 16, __nv_bfloat16, row_major> : public __frag_base<__nv_bfloat16, 8> {};
  template<> class fragment<matrix_a, 16, 16, 16, __nv_bfloat16, col_major> : public __frag_base<__nv_bfloat16, 8> {};
  template<> class fragment<matrix_b, 16, 16, 16, __nv_bfloat16, row_major> : public __frag_base<__nv_bfloat16, 8> {};
  template<> class fragment<matrix_b, 16, 16, 16, __nv_bfloat16, col_major> : public __frag_base<__nv_bfloat16, 8> {};
#endif  /* __CUDA_AMPERE_MMA__ */
  
  // 
  // Fragments for 32x8x16
  // 
  template<> class fragment<matrix_a, 32, 8, 16, __half, row_major> : public __frag_base<__half, 16> {};
  template<> class fragment<matrix_a, 32, 8, 16, __half, col_major> : public __frag_base<__half, 16> {};
  template<> class fragment<matrix_b, 32, 8, 16, __half, row_major> : public __frag_base<__half, 16> {};
  template<> class fragment<matrix_b, 32, 8, 16, __half, col_major> : public __frag_base<__half, 16> {};
  template<> class fragment<accumulator, 32, 8, 16, __half> : public __frag_base<__half, 8> {};
  template<> class fragment<accumulator, 32, 8, 16, float> : public __frag_base<float, 8> {};

#ifdef __CUDA_IMMA__
  template<> class fragment<matrix_a, 32, 8, 16, signed char, row_major> : public __frag_base<signed char, 16> {};
  template<> class fragment<matrix_a, 32, 8, 16, signed char, col_major> : public __frag_base<signed char, 16> {};
  template<> class fragment<matrix_a, 32, 8, 16, unsigned char, row_major> : public __frag_base<unsigned char, 16> {};
  template<> class fragment<matrix_a, 32, 8, 16, unsigned char, col_major> : public __frag_base<unsigned char, 16> {};
  template<> class fragment<matrix_b, 32, 8, 16, signed char, row_major> : public __frag_base<signed char, 4> {};
  template<> class fragment<matrix_b, 32, 8, 16, signed char, col_major> : public __frag_base<signed char, 4> {};
  template<> class fragment<matrix_b, 32, 8, 16, unsigned char, row_major> : public __frag_base<unsigned char, 4> {};
  template<> class fragment<matrix_b, 32, 8, 16, unsigned char, col_major> : public __frag_base<unsigned char, 4> {};
  template<> class fragment<accumulator, 32, 8, 16, int> : public __frag_base<int, 8> {};
#endif  /* __CUDA_IMMA__ */

#ifdef __CUDA_AMPERE_MMA__
  template<> class fragment<matrix_a, 32, 8, 16, __nv_bfloat16, row_major> : public __frag_base<__nv_bfloat16, 16> {};
  template<> class fragment<matrix_a, 32, 8, 16, __nv_bfloat16, col_major> : public __frag_base<__nv_bfloat16, 16> {};
  template<> class fragment<matrix_b, 32, 8, 16, __nv_bfloat16, row_major> : public __frag_base<__nv_bfloat16, 4> {};
  template<> class fragment<matrix_b, 32, 8, 16, __nv_bfloat16, col_major> : public __frag_base<__nv_bfloat16, 4> {};
#endif  /* __CUDA_AMPERE_MMA__ */
  
  // 
  // Fragments for 8x32x16
  // 
  template<> class fragment<matrix_a, 8, 32, 16, __half, row_major> : public __frag_base<__half, 16> {};
  template<> class fragment<matrix_a, 8, 32, 16, __half, col_major> : public __frag_base<__half, 16> {};
  template<> class fragment<matrix_b, 8, 32, 16, __half, row_major> : public __frag_base<__half, 16> {};
  template<> class fragment<matrix_b, 8, 32, 16, __half, col_major> : public __frag_base<__half, 16> {};
  template<> class fragment<accumulator, 8, 32, 16, __half> : public __frag_base<__half, 8> {};
  template<> class fragment<accumulator, 8, 32, 16, float> : public __frag_base<float, 8> {};

#ifdef __CUDA_IMMA__
  template<> class fragment<matrix_a, 8, 32, 16, signed char, row_major> : public __frag_base<signed char, 4> {};
  template<> class fragment<matrix_a, 8, 32, 16, signed char, col_major> : public __frag_base<signed char, 4> {};
  template<> class fragment<matrix_a, 8, 32, 16, unsigned char, row_major> : public __frag_base<unsigned char, 4> {};
  template<> class fragment<matrix_a, 8, 32, 16, unsigned char, col_major> : public __frag_base<unsigned char, 4> {};
  template<> class fragment<matrix_b, 8, 32, 16, signed char, row_major> : public __frag_base<signed char, 16> {};
  template<> class fragment<matrix_b, 8, 32, 16, signed char, col_major> : public __frag_base<signed char, 16> {};
  template<> class fragment<matrix_b, 8, 32, 16, unsigned char, row_major> : public __frag_base<unsigned char, 16> {};
  template<> class fragment<matrix_b, 8, 32, 16, unsigned char, col_major> : public __frag_base<unsigned char, 16> {};
  template<> class fragment<accumulator, 8, 32, 16, int> : public __frag_base<int, 8> {};
#endif  /* __CUDA_IMMA__ */

#ifdef __CUDA_AMPERE_MMA__
  template<> class fragment<matrix_a, 8, 32, 16, __nv_bfloat16, row_major> : public __frag_base<__nv_bfloat16, 4> {};
  template<> class fragment<matrix_a, 8, 32, 16, __nv_bfloat16, col_major> : public __frag_base<__nv_bfloat16, 4> {};
  template<> class fragment<matrix_b, 8, 32, 16, __nv_bfloat16, row_major> : public __frag_base<__nv_bfloat16, 16> {};
  template<> class fragment<matrix_b, 8, 32, 16, __nv_bfloat16, col_major> : public __frag_base<__nv_bfloat16, 16> {};
#endif  /* __CUDA_AMPERE_MMA__ */  
  
#ifdef __CUDA_SUBBYTE_IMMA__
  // 
  // Fragments for 8x8x32
  // 
  template<> class fragment<matrix_a, 8, 8, 32, experimental::precision::u4, row_major> : public __frag_base<experimental::precision::u4, 8, 1> {};
  template<> class fragment<matrix_a, 8, 8, 32, experimental::precision::s4, row_major> : public __frag_base<experimental::precision::s4, 8, 1> {};
  template<> class fragment<matrix_b, 8, 8, 32, experimental::precision::u4, col_major> : public __frag_base<experimental::precision::u4, 8, 1> {};
  template<> class fragment<matrix_b, 8, 8, 32, experimental::precision::s4, col_major> : public __frag_base<experimental::precision::s4, 8, 1> {};
  template<> class fragment<accumulator, 8, 8, 32, int> : public __frag_base<int, 2> {};

  // 
  // Fragments for 8x8x128
  // 
  template<> class fragment<matrix_a, 8, 8, 128, experimental::precision::b1, row_major> : public __frag_base<experimental::precision::b1, 32, 1> {};
  template<> class fragment<matrix_b, 8, 8, 128, experimental::precision::b1, col_major> : public __frag_base<experimental::precision::b1, 32, 1> {};
  template<> class fragment<accumulator, 8, 8, 128, int> : public __frag_base<int, 2> {};
#endif  /* __CUDA_SUBBYTE_IMMA__ */

#ifdef __CUDA_AMPERE_MMA__
  //
  // Fragments for 16x16x8
  //
  template<> class fragment<matrix_a, 16, 16, 8, precision::tf32, row_major> : public __frag_base<precision::tf32, 4> {};
  template<> class fragment<matrix_a, 16, 16, 8, precision::tf32, col_major> : public __frag_base<precision::tf32, 4> {};
  template<> class fragment<matrix_b, 16, 16, 8, precision::tf32, row_major> : public __frag_base<precision::tf32, 4> {};
  template<> class fragment<matrix_b, 16, 16, 8, precision::tf32, col_major> : public __frag_base<precision::tf32, 4> {};
  template<> class fragment<accumulator, 16, 16, 8, float> : public __frag_base<float, 8> {};
  
  //
  // Fragments for 8x8x4
  //
  template<> class fragment<matrix_a, 8, 8, 4, double, row_major> : public __frag_base<double, 1> {};
  template<> class fragment<matrix_a, 8, 8, 4, double, col_major> : public __frag_base<double, 1> {};
  template<> class fragment<matrix_b, 8, 8, 4, double, row_major> : public __frag_base<double, 1> {};
  template<> class fragment<matrix_b, 8, 8, 4, double, col_major> : public __frag_base<double, 1> {};
  template<> class fragment<accumulator, 8, 8, 4, double> : public __frag_base<double, 2> {};
#endif  /* __CUDA_AMPERE_MMA__ */  

So wmma::fragment<wmma::matrix_a, 16, 16, 16, float, wmma::row_major> a_frag; should not be exist in the code.

After re-set the size and precision of wmma::fragment, everything is perfect!

Hi do you have a complete code snippet for reference somewhere? Thanks a lot

So sorry for the late reply.

The full code is already posted above:

It’s kind like the lltm_cuda.cpp in the link below.

More detailed information could be found in:

CUSTOM C++ AND CUDA EXTENSIONS.