For practice, I implemented the forward() and backward() CUDA kernels for a GRU, following this tutorial:
https://pytorch.org/tutorials/advanced/cpp_extension.html
However, even though my CUDA implementation is faster than the pure-Python version, it still lags behind PyTorch's built-in GRU implementation:
Forward pass average time:
Python on GPU: 323
Fused cuda on GPU: 178
Torch library GRU on GPU: 94
I would like to know what I can do to bring my implementation up to the library's performance.
Any advice?
My GRU implementation can be found here
Below is my code for the forward pass:
/// Element-wise GRU-cell update: each thread handles one (batch row, hidden column).
///
/// Indexing used below: blockIdx.y selects the batch row `n`;
/// blockIdx.x * blockDim.x + threadIdx.x selects the hidden column `c`.
///
/// @param gate_x    Input projection, indexed [n][gate][c]; gate 0/1/2 feeds the
///                  reset / input / new gate respectively.
/// @param gate_h    Hidden projection, same layout as gate_x.
/// @param old_h     Previous hidden state, indexed [n][c].
/// @param resetgate Output: sigmoid(gate_x[n][0][c] + gate_h[n][0][c]).
/// @param inputgate Output: sigmoid(gate_x[n][1][c] + gate_h[n][1][c]).
/// @param newgate   Output: tanh(gate_x[n][2][c] + resetgate * gate_h[n][2][c]).
/// @param new_h     Output: newgate + inputgate * (old_h - newgate).
///
/// `sigmoid` is a device helper defined elsewhere in this file (not shown here).
template <typename scalar_t>
__global__ void gru_cuda_forward_kernel(
    const torch::PackedTensorAccessor<scalar_t,3,torch::RestrictPtrTraits,size_t> gate_x,
    const torch::PackedTensorAccessor<scalar_t,3,torch::RestrictPtrTraits,size_t> gate_h,
    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> old_h,
    torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> resetgate,
    torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> inputgate,
    torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> newgate,
    torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> new_h) {
  // Batch index.
  const size_t n = blockIdx.y;
  // Column index; size_t matches the accessors' index type, so the bound checks
  // below are not signed/unsigned comparisons.
  const size_t c = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;

  // Threads past the last column (the final block may be partially filled) do nothing.
  if (c >= gate_x.size(2) || c >= gate_h.size(2)) {
    return;
  }

  // Keep every intermediate in registers: each output is written to global
  // memory exactly once and never read back (the original re-loaded
  // resetgate and newgate after storing them).
  const scalar_t reset = sigmoid(gate_x[n][0][c] + gate_h[n][0][c]);
  const scalar_t input = sigmoid(gate_x[n][1][c] + gate_h[n][1][c]);
  const scalar_t candidate = tanh(gate_x[n][2][c] + reset * gate_h[n][2][c]);
  const scalar_t h_prev = old_h[n][c];

  resetgate[n][c] = reset;
  inputgate[n][c] = input;
  newgate[n][c] = candidate;
  // Same value as (1 - input) * candidate + input * h_prev.
  new_h[n][c] = candidate + input * (h_prev - candidate);
}
/// Host entry point for the fused GRU-cell forward pass (one time step).
///
/// Computes the input and hidden projections with `torch::addmm`, then launches
/// gru_cuda_forward_kernel to apply all element-wise gate math in a single kernel.
///
/// @param input  Step input; multiplied by x2h_w^T (presumably (batch, input_size) — TODO confirm).
/// @param x2h_w  Input-to-hidden weights; x2h_w^T must have 3 * state_size columns
///               so the result can be reshaped to (batch, 3, state_size).
/// @param h2h_w  Hidden-to-hidden weights; same column requirement as x2h_w.
/// @param x2h_b  Bias added to the input projection.
/// @param h2h_b  Bias added to the hidden projection.
/// @param old_h  Previous hidden state, (batch_size, state_size) per the size() reads below.
/// @return {new_h, resetgate, inputgate, newgate, gate_h}. gate_h is the
///         (batch, 3, state_size) hidden projection, presumably returned for the
///         backward pass — verify against the caller.
std::vector<torch::Tensor> gru_cuda_forward(
    torch::Tensor input,
    torch::Tensor x2h_w,
    torch::Tensor h2h_w,
    torch::Tensor x2h_b,
    torch::Tensor h2h_b,
    torch::Tensor old_h) {
  auto gate_x_weights = torch::addmm(x2h_b, input, x2h_w.transpose(0, 1));
  auto gate_h_weights = torch::addmm(h2h_b, old_h, h2h_w.transpose(0, 1));

  const auto batch_size = old_h.size(0);
  const auto state_size = old_h.size(1);

  // View each projection as (batch, gate, column); gates 0/1/2 are consumed as
  // reset / input / new by gru_cuda_forward_kernel.
  auto gate_x = gate_x_weights.reshape({batch_size, 3, state_size});
  auto gate_h = gate_h_weights.reshape({batch_size, 3, state_size});

  // The kernel writes every (n, c) element of these outputs, so zero-filling
  // them first (zeros_like) would only add an extra fill per tensor.
  auto resetgate = torch::empty_like(old_h);
  auto inputgate = torch::empty_like(old_h);
  auto newgate = torch::empty_like(old_h);
  auto new_h = torch::empty_like(old_h);

  // An empty batch or zero-width state leaves nothing to compute, and a grid
  // dimension of 0 is an invalid launch configuration.
  if (batch_size > 0 && state_size > 0) {
    // Size each block to the state width, rounded up to a whole warp and capped
    // at 1024 threads, so small hidden sizes don't launch mostly idle threads.
    constexpr int64_t kWarpSize = 32;
    constexpr int64_t kMaxThreads = 1024;
    const int64_t rounded = (state_size + kWarpSize - 1) / kWarpSize * kWarpSize;
    const int threads = static_cast<int>(rounded < kMaxThreads ? rounded : kMaxThreads);
    // One grid row per batch element (the kernel reads n from blockIdx.y);
    // note gridDim.y is limited to 65535 by the hardware.
    const dim3 blocks((state_size + threads - 1) / threads, batch_size);

    // NOTE(review): this launches on the legacy default stream, not PyTorch's
    // current stream (at::cuda::getCurrentCUDAStream()); pass that stream if the
    // extension is used under non-default streams.
    // packed_accessor<..., size_t> is kept because the kernel's parameters use
    // size_t as the index type.
    AT_DISPATCH_FLOATING_TYPES(gate_x.scalar_type(), "gru_forward_cuda", ([&] {
      gru_cuda_forward_kernel<scalar_t><<<blocks, threads>>>(
          gate_x.packed_accessor<scalar_t,3,torch::RestrictPtrTraits,size_t>(),
          gate_h.packed_accessor<scalar_t,3,torch::RestrictPtrTraits,size_t>(),
          old_h.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
          resetgate.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
          inputgate.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
          newgate.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
          new_h.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>());
    }));

    // Surface bad launch configurations immediately instead of at a later,
    // unrelated CUDA call.
    const cudaError_t launch_status = cudaGetLastError();
    TORCH_CHECK(launch_status == cudaSuccess,
                "gru_cuda_forward_kernel launch failed: ",
                cudaGetErrorString(launch_status));
  }

  return {new_h, resetgate, inputgate, newgate, gate_h};
}
Thanks.