I wrote a simple CUDA matrix multiplication kernel:
template <typename scalar_t>
__global__ void matmul_cuda_kernel(
const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits> a,
const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits> b,
torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits> out,
const int in_size,
const int a_size,
const int b_size
) {
const int col = blockIdx.y * blockDim.y + threadIdx.y;
const int row = blockIdx.x * blockDim.x + threadIdx.x;
if (row < a_size && col < b_size) {
scalar_t val = 0;
for (int i = 0; i < in_size; ++i) {
scalar_t a_f = a[row][i];
scalar_t b_f = b[i][col];
val += a_f * b_f;
}
out[row][col] = val;
}
}
Tensors are dispatched with:
// Instantiate and launch matmul_cuda_kernel for float, double and half
// inputs, selected from a's dtype. a.scalar_type() replaces the deprecated
// a.type() (DeprecatedTypeProperties) argument; the set of dispatched dtypes
// is unchanged.
// NOTE(review): blocks / threads_per_block come from the enclosing function
// (not shown). The kernel maps grid x to rows (< a_size) and grid y to
// columns (< b_size), so the launch geometry must cover both.
AT_DISPATCH_FLOATING_TYPES_AND_HALF(a.scalar_type(), "matmul_cuda", ([&] {
  matmul_cuda_kernel<scalar_t><<<blocks, threads_per_block>>>(
      a.packed_accessor<scalar_t,2,torch::RestrictPtrTraits>(),
      b.packed_accessor<scalar_t,2,torch::RestrictPtrTraits>(),
      out.packed_accessor<scalar_t,2,torch::RestrictPtrTraits>(),
      in_size, a_size, b_size);
} ) );
When I test the accuracy of the kernel by comparing its output with F.linear(a, b, None) on random tensors, I get no deviation for torch.float and a deviation of around 1e-12 for torch.double, but for torch.float16 the deviation reaches 1%. Of course, PyTorch's matmul implementation differs from mine, but why is the difference so much larger for float16 than for float and double?