Hello, I’m trying to substitute my QKV attention function with
torch.nn.functional.scaled_dot_product_attention to benefit from memory-efficient attention. However, the result is very slightly different from my own implementation, which follows the standard attention calculation.
- Training-wise, I notice slower convergence and often unstable training when I use SDPA, so I’m trying to understand where these differences come from.
- I checked which backend is being invoked (see the sketch after this list), and the C++ implementation seems to be the only one that runs without erroring out.
- This occurs with or without mixed precision.
- I set torch.backends.cudnn.deterministic = True and it did not change the output.
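For reference, a minimal sketch of how the backends can be isolated one at a time (this assumes the torch.backends.cuda.sdp_kernel context manager from PyTorch 2.0/2.1; newer releases expose torch.nn.attention.sdpa_kernel instead, and the shapes/dtype below are placeholders rather than my real model code):

```python
import torch
import torch.nn.functional as F

# Placeholder inputs: fp16 on CUDA so the fused kernels are eligible,
# and a smaller L than the real 32768 to keep the check cheap.
B, L, E = 1, 1024, 192
q = torch.randn(B, L, E, device="cuda", dtype=torch.float16)
k = torch.randn(B, L, E, device="cuda", dtype=torch.float16)
v = torch.randn(B, L, E, device="cuda", dtype=torch.float16)

# Enable exactly one backend at a time; the combinations that error out
# reveal which kernel SDPA was actually dispatching to.
with torch.backends.cuda.sdp_kernel(
    enable_flash=False,
    enable_math=True,           # the C++ math fallback
    enable_mem_efficient=False,
):
    out = F.scaled_dot_product_attention(q, k, v)
```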
My variables are query, key, and value, all of shape (B, L, E) = (1, 32768, 192).
Here is the function I’m using:
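(A minimal sketch of the call, assuming the defaults: no attn_mask, dropout_p=0, and the default 1/sqrt(E) scaling.)

```python
import torch.nn.functional as F

# query, key, and value are the (1, 32768, 192) tensors described above;
# x is the fused-attention output compared against x2 below.
x = F.scaled_dot_product_attention(query, key, value)
```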
Here is my older implementation:

attn_weight = torch.softmax((query @ key.transpose(-2, -1) / math.sqrt(query.size(-1))), dim=-1)
x2 = attn_weight @ value
I expect x to be exactly equal to x2, but there are minor differences in some elements of the output tensor.
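To make the comparison concrete, here is a self-contained repro sketch (the seed, the smaller sequence length, and the float32 CPU tensors are assumptions for illustration, not my real setup):

```python
import math

import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, L, E = 1, 1024, 192
query = torch.randn(B, L, E)
key = torch.randn(B, L, E)
value = torch.randn(B, L, E)

# Fused path.
x = F.scaled_dot_product_attention(query, key, value)

# Manual path, mathematically identical to SDPA's default behavior.
attn_weight = torch.softmax(
    query @ key.transpose(-2, -1) / math.sqrt(query.size(-1)), dim=-1
)
x2 = attn_weight @ value

print(torch.equal(x, x2))                # False: not bitwise identical
print(torch.allclose(x, x2, atol=1e-5))  # expected True if it's only float noise
print((x - x2).abs().max())              # magnitude of the discrepancy
```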
Am I understanding the scaled_dot_product_attention function correctly? I would have no issue if the outputs were only very slightly off, but it seems that this affects training significantly.