Optimizing GRU CUDA kernel

For practice, I implemented the forward() and backward() CUDA kernels for GRU, following this tutorial:
https://pytorch.org/tutorials/advanced/cpp_extension.html

However, even though my fused CUDA implementation is faster than the pure Python version, it still lags behind PyTorch's built-in library GRU implementation:

Forward pass average time:
Python on GPU: 323
Fused cuda on GPU: 178
Torch library GRU on GPU: 94

I am trying to see whether there are things I can do to make my implementation match the library performance.

Any advice?

My GRU implementation can be found here

Below is my code for the forward pass:

// Element-wise stage of a GRU forward step.
//
// Every thread handles a single (row, column) element: blockIdx.y picks the
// row (first index of all accessors) and the x-dimension of the launch picks
// the column (last index). The middle index of gate_x / gate_h selects the
// gate slice: 0 = reset, 1 = input, 2 = new.
//
// Outputs written per element:
//   resetgate = sigmoid(gate_x[0] + gate_h[0])
//   inputgate = sigmoid(gate_x[1] + gate_h[1])
//   newgate   = tanh(gate_x[2] + resetgate * gate_h[2])
//   new_h     = newgate + inputgate * (old_h - newgate)
template <typename scalar_t>
__global__ void gru_cuda_forward_kernel(
    const torch::PackedTensorAccessor<scalar_t,3,torch::RestrictPtrTraits,size_t> gate_x,
    const torch::PackedTensorAccessor<scalar_t,3,torch::RestrictPtrTraits,size_t> gate_h,
    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> old_h,
    torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> resetgate,
    torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> inputgate,
    torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> newgate,
    torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> new_h) {
      const int row = blockIdx.y;
      const int col = blockIdx.x * blockDim.x + threadIdx.x;

      // The last block along x may extend past the final column; those
      // threads have no element to process.
      if (col >= gate_x.size(2) || col >= gate_h.size(2)) {
        return;
      }

      // Hold each gate value in a local so later expressions reuse it
      // directly instead of reading back the element that was just stored.
      const scalar_t reset_val = sigmoid(gate_x[row][0][col] + gate_h[row][0][col]);
      resetgate[row][col] = reset_val;

      const scalar_t input_val = sigmoid(gate_x[row][1][col] + gate_h[row][1][col]);
      inputgate[row][col] = input_val;

      const scalar_t new_val = tanh(gate_x[row][2][col] + reset_val * gate_h[row][2][col]);
      newgate[row][col] = new_val;

      new_h[row][col] = new_val + input_val * (old_h[row][col] - new_val);
}

// Host-side forward pass for one GRU step.
//
// Steps:
//   1. gate_x = x2h_b + input * x2h_w^T and gate_h = h2h_b + old_h * h2h_w^T
//      (two addmm calls), each reshaped to {batch_size, 3, state_size} where
//      batch_size = old_h.size(0) and state_size = old_h.size(1).
//   2. gru_cuda_forward_kernel applies the element-wise gate math and fills
//      resetgate, inputgate, newgate and new_h (all shaped like old_h).
//
// Returns {new_h, resetgate, inputgate, newgate, gate_h}; the tensors after
// new_h are presumably kept for the backward pass -- confirm against caller.
std::vector<torch::Tensor> gru_cuda_forward(
    torch::Tensor input,
    torch::Tensor x2h_w,
    torch::Tensor h2h_w,
    torch::Tensor x2h_b,
    torch::Tensor h2h_b,
    torch::Tensor old_h) {
        auto gate_x_weights = torch::addmm(x2h_b, input, x2h_w.transpose(0, 1));
        auto gate_h_weights = torch::addmm(h2h_b, old_h, h2h_w.transpose(0, 1));

        const auto batch_size = old_h.size(0);
        const auto state_size = old_h.size(1);

        auto gate_x = gate_x_weights.reshape({batch_size, 3, state_size});
        auto gate_h = gate_h_weights.reshape({batch_size, 3, state_size});

        // The kernel stores to every (row, column) element of these four
        // tensors, so pre-filling them with zeros is wasted work (one extra
        // fill pass per tensor). Allocate them uninitialized instead.
        auto resetgate = torch::empty_like(old_h);
        auto inputgate = torch::empty_like(old_h);
        auto newgate = torch::empty_like(old_h);
        auto new_h = torch::empty_like(old_h);

        // An empty batch or a zero-sized state would give a grid dimension of
        // 0, which is an invalid launch configuration. There is nothing to
        // compute in that case, so skip the launch and return the (empty)
        // outputs as they are.
        if (old_h.numel() > 0) {
            const int threads = 1024;
            // x covers the columns in chunks of `threads`; y has one block
            // row per batch entry.
            // NOTE(review): gridDim.y has a hardware limit, so very large
            // batch sizes would need the batch index moved to the x-dimension.
            const dim3 blocks((state_size + threads - 1) / threads, batch_size);

            AT_DISPATCH_FLOATING_TYPES(gate_x.type(), "gru_forward_cuda", ([&] {
                gru_cuda_forward_kernel<scalar_t><<<blocks, threads>>>(
                    gate_x.packed_accessor<scalar_t,3,torch::RestrictPtrTraits,size_t>(),
                    gate_h.packed_accessor<scalar_t,3,torch::RestrictPtrTraits,size_t>(),
                    old_h.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
                    resetgate.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
                    inputgate.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
                    newgate.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
                    new_h.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>());
              }));
        }

        return {
          new_h, resetgate, inputgate, newgate, gate_h
        };
}

Thanks.