Hello
I built a custom attention module using my own CUDA kernel, which is bound to Python via ctypes (CDLL).
class AttnCUDA(nn.Module):
    """Multi-head attention whose forward/backward run in a custom CUDA kernel
    loaded from ``./attn_new.so`` via ctypes.

    Inputs to :meth:`forward` are assumed to be float32 CUDA tensors of shape
    ``(batch, heads, seq_len, emb_dim)`` -- TODO confirm against the kernel's
    contract; the code only shows that it indexes exactly four dimensions.

    Fixes over the original implementation:

    * ``t.contiguous().data_ptr()`` was called inline.  When ``t`` is
      non-contiguous, ``.contiguous()`` allocates a *temporary* tensor whose
      refcount drops to zero right after ``data_ptr()`` returns, so PyTorch's
      caching allocator may recycle that device memory before (or while) the
      asynchronously launched kernel reads it.  Every contiguous tensor is now
      bound to a local that stays alive across the kernel call.
    * ``torch.cuda.synchronize()`` brackets each raw kernel call.  The kernel
      presumably launches on the default CUDA stream while PyTorch ops run on
      its current stream -- TODO confirm; passing
      ``torch.cuda.current_stream()`` into the kernel would be the proper fix.
    * Gradients are returned ``.contiguous()`` in the input layout.
    * The C-side handle returned by ``init()`` is released in ``__del__``
      (the library declares ``destroy`` but the original never called it).
    """

    def __init__(self, emb_dim):
        super().__init__()
        attn_cpp = AttnCUDA._compile_module()
        attn_handle = attn_cpp.init(emb_dim)
        # Keep the library and handle on self so they outlive the closure
        # below and can be released in __del__.
        self._attn_cpp = attn_cpp
        self._attn_handle = attn_handle

        class AttnFunction(torch.autograd.Function):
            @staticmethod
            def forward(ctx, query, key, value):
                batch_size, num_heads, seq_len, emb_dim = query.shape
                num_batches = batch_size * num_heads

                # Kernel expects query/value transposed to (.., emb, seq);
                # bind every contiguous tensor to a local so its device
                # memory stays alive across the kernel call (see class doc).
                h_query = query.permute(0, 1, 3, 2).contiguous()
                h_value = value.permute(0, 1, 3, 2).contiguous()
                h_key = key.contiguous()

                # Output/scratch buffers written by the kernel.
                h_attn = torch.empty((batch_size, num_heads, seq_len, seq_len),
                                     dtype=torch.float32, device='cuda')
                h_attn_score = torch.empty((batch_size, num_heads, seq_len, seq_len),
                                           dtype=torch.float32, device='cuda')
                h_out = torch.empty((batch_size, num_heads, seq_len, emb_dim),
                                    dtype=torch.float32, device='cuda')

                # Ensure all pending PyTorch work on the inputs is finished
                # before the raw kernel (likely on the default stream) reads
                # them, and that its writes are visible afterwards.
                torch.cuda.synchronize()
                attn_cpp.attn_forward(attn_handle,
                                      h_query.data_ptr(), h_key.data_ptr(),
                                      h_value.data_ptr(), h_attn.data_ptr(),
                                      h_attn_score.data_ptr(), h_out.data_ptr(),
                                      seq_len, emb_dim, num_batches)
                torch.cuda.synchronize()

                # Kernel emits scores/outputs transposed; bring them back to
                # (batch, heads, seq, *) layout.
                attn_score = h_attn_score.permute(0, 1, 3, 2)
                ctx.save_for_backward(query, key, value, attn_score)
                out = h_out.view(batch_size, num_heads, emb_dim, seq_len).permute(0, 1, 3, 2)
                context_layer = out.permute(0, 2, 1, 3).contiguous()
                return context_layer, attn_score

            @staticmethod
            def backward(ctx, grad_output, grad_weight):
                query, key, value, attn_score = ctx.saved_tensors
                batch_size, num_heads, seq_len, emb_dim = query.shape
                num_batches = batch_size * num_heads

                # zeros for the two buffers the kernel accumulates into,
                # empty for the ones it overwrites (kept as in the original).
                g_attn_score = torch.zeros(num_batches * seq_len * seq_len,
                                           dtype=torch.float32, device='cuda')
                g_attn_score_scale = torch.zeros(num_batches * seq_len * seq_len,
                                                 dtype=torch.float32, device='cuda')
                g_attn_scale = torch.empty(num_batches * seq_len * seq_len,
                                           dtype=torch.float32, device='cuda')
                g_attn = torch.empty(num_batches * seq_len * seq_len,
                                     dtype=torch.float32, device='cuda')
                g_query = torch.empty(num_batches * seq_len * emb_dim,
                                      dtype=torch.float32, device='cuda')
                g_key = torch.empty(num_batches * seq_len * emb_dim,
                                    dtype=torch.float32, device='cuda')
                g_value = torch.empty(num_batches * seq_len * emb_dim,
                                      dtype=torch.float32, device='cuda')

                # Contiguous copies bound to locals for the kernel's lifetime
                # (the original's inline .contiguous().data_ptr() could
                # dangle -- see class docstring).
                h_grad_output = grad_output.permute(0, 1, 3, 2).contiguous()
                h_key = key.permute(0, 1, 3, 2).contiguous()
                h_query = query.permute(0, 1, 3, 2).contiguous()
                h_value = value.contiguous()
                h_attn_score = attn_score.contiguous()

                torch.cuda.synchronize()
                attn_cpp.attn_backward(attn_handle,
                                       h_query.data_ptr(), h_key.data_ptr(),
                                       h_value.data_ptr(), h_attn_score.data_ptr(),
                                       h_grad_output.data_ptr(),
                                       g_attn_score.data_ptr(),
                                       g_attn_score_scale.data_ptr(),
                                       g_attn_scale.data_ptr(), g_attn.data_ptr(),
                                       g_query.data_ptr(), g_key.data_ptr(),
                                       g_value.data_ptr(),
                                       seq_len, emb_dim, num_batches)
                torch.cuda.synchronize()

                # Return gradients in the (batch, heads, seq, emb) layout of
                # the inputs; .contiguous() so gradient accumulation and the
                # optimizer see plain row-major tensors.
                grad_query = g_query.view(batch_size, num_heads, emb_dim, seq_len) \
                                    .permute(0, 1, 3, 2).contiguous()
                grad_key = g_key.view(batch_size, num_heads, emb_dim, seq_len) \
                                .permute(0, 1, 3, 2).contiguous()
                grad_value = g_value.view(batch_size, num_heads, emb_dim, seq_len) \
                                    .permute(0, 1, 3, 2).contiguous()
                return grad_query, grad_key, grad_value

        self.attn_func = AttnFunction

    def forward(self, query, key, value):
        """Run the custom attention kernel; returns (context_layer, attn_score)."""
        return self.attn_func.apply(query, key, value)

    def __del__(self):
        # Release the C-side context created by init(); the library declares
        # destroy() but the original never called it (handle leak).
        handle = getattr(self, '_attn_handle', None)
        if handle is not None:
            try:
                self._attn_cpp.destroy(handle)
            except Exception:
                pass  # best-effort cleanup during interpreter teardown
            self._attn_handle = None

    @staticmethod
    def _compile_module():
        """Load ``./attn_new.so`` and declare its ctypes signatures.

        ``init.restype`` must be ``c_void_p``: without it ctypes defaults to
        a C int and would truncate a 64-bit pointer handle.
        """
        attn_cpp = ctypes.CDLL('./attn_new.so')
        attn_cpp.init.argtypes = [c_int]
        attn_cpp.init.restype = ctypes.c_void_p
        attn_cpp.attn_forward.argtypes = [ctypes.c_void_p] + [ctypes.c_void_p] * 6 + [c_int] * 3
        attn_cpp.attn_backward.argtypes = [ctypes.c_void_p] + [ctypes.c_void_p] * 12 + [c_int] * 3
        attn_cpp.destroy.argtypes = [ctypes.c_void_p]
        return attn_cpp
like this.
And it does operate correctly: I verified every value against an equivalent pure-PyTorch implementation.
So when I train using this attention module, the output values and the loss of the first iteration are exactly the same as with the PyTorch code. But after the first iteration the values diverge, and the model ultimately does not train well — it does train, just not as well as I expected. With the PyTorch code the accuracy starts at 82% and ends at 86%, but with the custom CUDA code it starts at 66% and ends at 80%. I suspect the optimizer is involved in this issue. Is there any way I could solve it?
+) If I rebuild this with the PyTorch C++/CUDA extension API instead of ctypes, would this issue be solved?