What does the CUDA kernel name 'volta_sgemm_128x32_nn' mean?

I am studying how PyTorch's matmul function runs on NVIDIA GPUs.

```python
import torch

### variable creation
a = torch.randn(size=(1, 128, 3), dtype=torch.float32).to("cuda")
b = torch.randn(size=(1, 3, 32), dtype=torch.float32).to("cuda")

### execution
c = torch.matmul(a, b)  # shape (1, 128, 32)
```

I profiled this code using pyprof, which gave me the result below.

There are many things in it that I cannot understand.

  1. What does sgemm_128x32 mean? I see that the ‘s’ in sgemm stands for single precision and ‘gemm’ means general matrix multiplication, but I don’t know what 128x32 means. My output matrix is 128 by 32. I also know that CUTLASS optimizes sgemm using outer products (see ref 1), but I could not really understand that article.

(1) Does 128x32 simply mean the output matrix’s dimensions? (2) Is there any way to see how my output matrix (c in my code) is actually calculated? (For example, are there 128*32 threads in total, each responsible for one output element computed as an inner product?)

  2. Why do the Grid and the Block each have 3 dimensions, and how are the grid and block used for sgemm_128x32? The Grid consists of x, y, z, and the Block consists of x, y, z. (1) Why are 3 dimensions needed? I see (in the picture above) that Block X has 256 threads. (2) Is this true? And Grid Y is 4, so this means there are 4 blocks in the Y dimension of the Grid. (3) Is this true?
  3. Using that pyprof result, can I figure out how many SMs are used, and how many warps are active on each SM?
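To make question 1 concrete, here is a minimal pure-Python sketch of the two formulations it mentions. It only illustrates the math, not how cuBLAS or CUTLASS actually implement the kernel, and the function names are hypothetical. The inner-product form computes each element of C independently (one "thread" per element), while the outer-product form accumulates one rank-1 update of the whole C per step along K; both produce the same result.

```python
import random


def matmul_inner(A, B):
    """Inner-product form: each C[i][j] is its own dot product over K."""
    M, K, N = len(A), len(B), len(B[0])
    C = [[0.0] * N for _ in range(M)]
    for i in range(M):
        for j in range(N):
            # the "one thread per output element" scheme from question 1
            C[i][j] = sum(A[i][k] * B[k][j] for k in range(K))
    return C


def matmul_outer(A, B):
    """Outer-product form: C = sum over k of (column k of A) x (row k of B)."""
    M, K, N = len(A), len(B), len(B[0])
    C = [[0.0] * N for _ in range(M)]
    for k in range(K):
        # rank-1 update: every element of C gets one multiply-add per k step
        for i in range(M):
            a_ik = A[i][k]
            for j in range(N):
                C[i][j] += a_ik * B[k][j]
    return C


# Shapes from the question: (128 x 3) @ (3 x 32) -> 128 x 32
random.seed(0)
M, K, N = 128, 3, 32
A = [[random.uniform(-1, 1) for _ in range(K)] for _ in range(M)]
B = [[random.uniform(-1, 1) for _ in range(N)] for _ in range(K)]

C1 = matmul_inner(A, B)
C2 = matmul_outer(A, B)
max_diff = max(abs(x - y) for r1, r2 in zip(C1, C2) for x, y in zip(r1, r2))
print(len(C1), len(C1[0]), max_diff < 1e-12)  # 128 32 True
```

Because K is only 3 here, each element of C needs just three multiply-adds in either form; what differs on a GPU is how loaded values of A and B are reused, which is what the CUTLASS article discusses.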

Thank you.

ref 1: CUTLASS: Fast Linear Algebra in CUDA C++ (NVIDIA Developer Blog)

As you’ve already partly explained:

  • volta: GPU architecture
  • s: accumulator type, single precision in this case
  • gemm: kernel type, general matrix multiplication in this case
  • 128: number of elements of the C matrix per CTA in the M dimension
  • 32: number of elements of the C matrix per CTA in the N dimension
  • nn: storage mode of the A and B matrices, respectively; “n” stands for “normal” or “no-transpose” (column-major), which applies to both matrices in your case
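The decomposition above can be turned into a small parser. This is a hypothetical helper whose regular expression is inferred from the field list in this reply, not from any official NVIDIA naming specification, so other kernels may not follow it. It also counts how many 128x32 tiles are needed to cover C under a simple one-CTA-per-tile assumption; the grid cuBLAS actually launches is its own choice, so this count need not match the Grid values pyprof reports.

```python
import re

# Hypothetical pattern inferred from the field list above, not an official spec.
KERNEL_NAME = re.compile(
    r"^(?P<arch>[a-z]+)_(?P<prec>[a-z])gemm_"
    r"(?P<tile_m>\d+)x(?P<tile_n>\d+)_(?P<layout>[nt]{2})$"
)


def parse_kernel_name(name):
    """Split a name such as 'volta_sgemm_128x32_nn' into its fields."""
    m = KERNEL_NAME.match(name)
    if m is None:
        raise ValueError(f"unrecognized kernel name: {name}")
    return {
        "arch": m["arch"],              # GPU architecture
        "precision": m["prec"],         # 's' = single precision
        "tile_m": int(m["tile_m"]),     # C elements per CTA in M
        "tile_n": int(m["tile_n"]),     # C elements per CTA in N
        "layout_a": m["layout"][0],     # 'n' = no-transpose
        "layout_b": m["layout"][1],
    }


def tiles_covering_c(M, N, tile_m, tile_n):
    """Tiles of size tile_m x tile_n needed to cover an M x N matrix C (ceil division)."""
    return ((M + tile_m - 1) // tile_m) * ((N + tile_n - 1) // tile_n)


info = parse_kernel_name("volta_sgemm_128x32_nn")
print(info)
print(tiles_covering_c(128, 32, info["tile_m"], info["tile_n"]))      # 1 (the question's C)
print(tiles_covering_c(1024, 1024, info["tile_m"], info["tile_n"]))  # 256
```

For the question's 128 x 32 output, a single 128x32 tile already covers all of C.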

You could probably use Nsight Compute to profile this kernel in more detail.


Thank you for the reply, @ptrblck.

I don't understand the meaning of “128: number of elements per CTA in M dimension of the C matrix”.

As I understand it, CTA means the thread block, and the thread block (CTA) has 128 elements because of the ‘M’ dimension of the C matrix.

By ‘128 elements’, do you mean 128 threads in the thread block (CTA)?
So are there 128*32 threads in total in the CTA?
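Reading the reply as “each CTA computes a 128x32 tile of C”, the 128x32 counts output elements, not threads. A CTA could not hold 128*32 = 4096 threads anyway, because CUDA limits a thread block to 1024 threads. The sketch below does the arithmetic under one assumption taken from question 2 (that the block has 256 threads) and then shows one hypothetical way those threads could split the tile; the `owner` mapping is invented for illustration.

```python
# Assumption: 256 threads per CTA, as read from the pyprof output in question 2.
TILE_M, TILE_N = 128, 32       # C elements per CTA, from the kernel name
THREADS_PER_CTA = 256          # assumed, not derived from the name
MAX_THREADS_PER_BLOCK = 1024   # CUDA's per-block thread limit
WARP_SIZE = 32

outputs_per_cta = TILE_M * TILE_N                        # 4096
outputs_per_thread = outputs_per_cta // THREADS_PER_CTA  # 16
warps_per_cta = THREADS_PER_CTA // WARP_SIZE             # 8

# One thread per output element would need more threads than a block may have.
print(outputs_per_cta > MAX_THREADS_PER_BLOCK)             # True
print(outputs_per_cta, outputs_per_thread, warps_per_cta)  # 4096 16 8


def owner(i, j, sub_m=4, sub_n=4):
    """Hypothetical mapping: each thread owns a sub_m x sub_n block of the tile.

    The real thread layout inside volta_sgemm_128x32_nn is not known here.
    """
    blocks_per_row = TILE_N // sub_n  # 8 blocks across N
    return (i // sub_m) * blocks_per_row + (j // sub_n)


counts = [0] * THREADS_PER_CTA
for i in range(TILE_M):
    for j in range(TILE_N):
        counts[owner(i, j)] += 1

# Every thread owns exactly 16 outputs and every output has exactly one owner.
print(min(counts), max(counts), sum(counts))  # 16 16 4096
```

With 16 outputs per thread, each thread can keep its partial sums in registers and update them with an outer-product style accumulation over K, which is broadly the idea the CUTLASS article builds on.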

Thank you.