Custom autograd function optimizer

Hello,
I built a custom attention module using my own CUDA kernel, which is bound with ctypes.CDLL.

import ctypes
from ctypes import c_int

import torch
import torch.nn as nn

class AttnCUDA(nn.Module):
    def __init__(self, emb_dim):
        super(AttnCUDA, self).__init__()
        attn_cpp = AttnCUDA._compile_module()
        attn_handle = attn_cpp.init(emb_dim)

        class AttnFunction(torch.autograd.Function):

            @staticmethod
            def forward(ctx, query, key, value):
                batch_size = query.shape[0]
                num_heads = query.shape[1]
                num_batches = batch_size * num_heads
                seq_len = query.shape[2]
                emb_dim = query.shape[3]
                # transpose the last two dims and flatten before handing the data to the kernel
                tmp_hQuery = query.permute(0, 1, 3, 2).flatten()
                tmp_hValue = value.permute(0, 1, 3, 2).flatten()
                # buffers the kernel writes its results into
                hAttn = torch.empty((batch_size, num_heads, seq_len, seq_len), dtype=torch.float32, device='cuda')
                hAttnScore = torch.empty((batch_size, num_heads, seq_len, seq_len), dtype=torch.float32, device='cuda')
                hOut = torch.empty((batch_size, num_heads, seq_len, emb_dim), dtype=torch.float32, device='cuda')

                # raw device pointers passed to the ctypes-bound kernel
                hQuery_p = tmp_hQuery.contiguous().data_ptr()
                hKey_p = key.contiguous().data_ptr()
                hValue_p = tmp_hValue.contiguous().data_ptr()
                hAttn_p = hAttn.contiguous().data_ptr()
                hAttnScore_p = hAttnScore.contiguous().data_ptr()
                hOut_p = hOut.contiguous().data_ptr()

                attn_cpp.attn_forward(attn_handle, hQuery_p, hKey_p, hValue_p, hAttn_p, hAttnScore_p, hOut_p, seq_len, emb_dim, num_batches)
                
                ctx.save_for_backward(query, key, value, hAttnScore.view(batch_size, num_heads, seq_len, seq_len).permute(0, 1, 3, 2))
                # view the flat kernel output as (batch, heads, emb_dim, seq_len) and permute to (batch, heads, seq_len, emb_dim)
                out = hOut.view(batch_size, num_heads, emb_dim, seq_len).permute(0, 1, 3, 2)

                context_layer = out.permute(0, 2, 1, 3).contiguous()
                return context_layer, hAttnScore.view(batch_size, num_heads, seq_len, seq_len).permute(0, 1, 3, 2)

            @staticmethod
            def backward(ctx, grad_output, grad_weight):
                query, key, value, attn_score = ctx.saved_tensors
                batch_size = query.shape[0]
                num_heads = query.shape[1]
                num_batches = batch_size * num_heads
                seq_len = query.shape[2]
                emb_dim = query.shape[3]
                
                hGradAttnScore = torch.zeros(num_batches*seq_len*seq_len, dtype=torch.float32, device='cuda')
                hGradAttnScoreScale = torch.zeros(num_batches*seq_len*seq_len, dtype=torch.float32, device='cuda')
                hGradAttnScale = torch.empty(num_batches*seq_len*seq_len, dtype=torch.float32, device='cuda')
                hGradAttn = torch.empty(num_batches*seq_len*seq_len, dtype=torch.float32, device='cuda')
                hGradQuery = torch.empty(num_batches*seq_len*emb_dim, dtype=torch.float32, device='cuda')
                hGradKey = torch.empty(num_batches*seq_len*emb_dim, dtype=torch.float32, device='cuda')
                hGradValue = torch.empty(num_batches*seq_len*emb_dim, dtype=torch.float32, device='cuda')
                tmp_hGradOutput = grad_output.permute(0, 1, 3, 2)
                tmp_key = key.permute(0, 1, 3, 2).flatten()
                tmp_query = query.permute(0, 1, 3, 2).flatten()

                hQuery_p = tmp_query.contiguous().data_ptr()
                hKey_p = tmp_key.contiguous().data_ptr()
                hValue_p = value.contiguous().data_ptr()
                hAttnScore_p = attn_score.contiguous().data_ptr()
                hGradOutput_p = tmp_hGradOutput.contiguous().data_ptr()
                hGradAttnScore_p = hGradAttnScore.contiguous().data_ptr()
                hGradAttnScoreScale_p = hGradAttnScoreScale.contiguous().data_ptr()
                hGradAttnScale_p = hGradAttnScale.contiguous().data_ptr()
                hGradAttn_p = hGradAttn.contiguous().data_ptr()
                hGradQuery_p = hGradQuery.contiguous().data_ptr()
                hGradKey_p = hGradKey.contiguous().data_ptr()
                hGradValue_p = hGradValue.contiguous().data_ptr()

                attn_cpp.attn_backward(attn_handle, hQuery_p, hKey_p, hValue_p, hAttnScore_p, hGradOutput_p, hGradAttnScore_p, hGradAttnScoreScale_p, 
                        hGradAttnScale_p, hGradAttn_p, hGradQuery_p, hGradKey_p, hGradValue_p,
                        seq_len, emb_dim, num_batches)

                # view the flat gradient buffers as (batch, heads, emb_dim, seq_len) and permute to match the input shapes
                gradQuery = hGradQuery.view(batch_size, num_heads, emb_dim, seq_len).permute(0, 1, 3, 2)
                gradKey = hGradKey.view(batch_size, num_heads, emb_dim, seq_len).permute(0, 1, 3, 2)
                gradValue = hGradValue.view(batch_size, num_heads, emb_dim, seq_len).permute(0, 1, 3, 2)
                return gradQuery, gradKey, gradValue

        self.attn_func = AttnFunction

    def forward(self, query, key, value):
        return self.attn_func.apply(query, key, value)

    @staticmethod
    def _compile_module():
        attn_cpp = ctypes.CDLL('./attn_new.so')
        attn_cpp.init.argtypes = [c_int]
        attn_cpp.init.restype = ctypes.c_void_p
        attn_cpp.attn_forward.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p, 
                                          c_int, c_int, c_int]
        attn_cpp.attn_backward.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p, 
                                          ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p,
                                          c_int, c_int, c_int]
        attn_cpp.destroy.argtypes = [ctypes.c_void_p]
        return attn_cpp

That is the whole module.
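
For context, this is roughly how the module is called (the shapes here are just an example, not my real configuration, and it assumes ./attn_new.so has already been built):

q = torch.randn(2, 4, 128, 64, device='cuda', requires_grad=True)  # (batch, heads, seq_len, emb_dim)
k = torch.randn(2, 4, 128, 64, device='cuda', requires_grad=True)
v = torch.randn(2, 4, 128, 64, device='cuda', requires_grad=True)

attn = AttnCUDA(emb_dim=64)
context, attn_score = attn(q, k, v)  # context: (batch, seq_len, heads, emb_dim)
context.sum().backward()             # gradients flow through AttnFunction.backward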

And it does operate correctly: I verified every value against my plain-PyTorch implementation of the same attention.
So when I train using this attention module, the output values and the loss of the first iteration are exactly the same as with the PyTorch code. But after the first iteration the values diverge, and in the end the model does train, just not as well as I expected. With the PyTorch code the accuracy starts at 82% and ends at 86%; with the custom CUDA code it starts at 66% and ends at 80%. I suspect the optimizer is the reason for this issue. Is there any way I could solve it?
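
To make it concrete, this is the kind of check I mean by "verified" (the reference below is a simplified scaled-dot-product stand-in, not my exact torch code; the 1/sqrt(emb_dim) scaling and the shapes are just examples), extended here to the gradients of a single backward pass:

def ref_attention(query, key, value):
    # plain scaled-dot-product attention; the scaling is an assumption about the kernel
    attn_score = torch.softmax(query @ key.transpose(-2, -1) / query.shape[-1] ** 0.5, dim=-1)
    out = attn_score @ value                                  # (batch, heads, seq_len, emb_dim)
    return out.permute(0, 2, 1, 3).contiguous(), attn_score  # same output layout as AttnCUDA

q = torch.randn(2, 4, 128, 64, device='cuda', requires_grad=True)
k = torch.randn_like(q, requires_grad=True)
v = torch.randn_like(q, requires_grad=True)

cuda_out, _ = AttnCUDA(64)(q, k, v)
ref_out, _ = ref_attention(q, k, v)
print(torch.allclose(cuda_out, ref_out, atol=1e-5))  # should print True if the forward matches

# same comparison for the gradients of one backward pass
grad = torch.randn_like(cuda_out)
cuda_grads = torch.autograd.grad(cuda_out, (q, k, v), grad)
ref_grads = torch.autograd.grad(ref_out, (q, k, v), grad)
for g_cuda, g_ref in zip(cuda_grads, ref_grads):
    print((g_cuda - g_ref).abs().max())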

+) If I build the binding with the torch CUDA extension API instead of ctypes, would that solve this issue?
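
What I have in mind is something along these lines on the Python side (the source file names and the exported function names are only placeholders, not existing files):

from torch.utils.cpp_extension import load

# hypothetical sources; the C++ side would expose attn_forward/attn_backward via pybind11
attn_ext = load(
    name='attn_ext',
    sources=['attn_binding.cpp', 'attn_kernels.cu'],
    verbose=True,
)

# Inside AttnFunction the calls would then take tensors instead of raw pointers, e.g.
#   out, attn_score = attn_ext.attn_forward(query.contiguous(), key.contiguous(), value.contiguous())
#   grad_q, grad_k, grad_v = attn_ext.attn_backward(grad_output.contiguous(), query, key, value, attn_score)

My understanding is that the extension API would let PyTorch manage the tensor memory and the CUDA stream instead of me passing data_ptr() values through ctypes, but I am not sure whether that alone would change the training behaviour.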