Hello, I'm trying to substitute my QKV attention function with torch.nn.functional.scaled_dot_product_attention to benefit from memory-efficient attention. However, the result is slightly different from my own implementation, which mostly follows the standard attention calculation.
- Training-wise, I notice slower convergence and often unstable training when I use SDPA, so I'm trying to understand where these differences are coming from.
- I checked which backend is being invoked, and it seems the C++ implementation (the math backend) is the only one that runs without erroring out (see the sketch after this list).
- This occurs with or without mixed precision.
- I set torch.backends.cudnn.deterministic = True and it did not change the output.
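For reference, one way I tried to narrow this down is to restrict SDPA to a single backend at a time. This is only a minimal sketch, assuming query, key, and value are defined as described further down, and assuming the torch.backends.cuda.sdp_kernel context manager is available in this PyTorch version (newer releases expose torch.nn.attention.sdpa_kernel instead):

import torch
import torch.nn.functional as F

# Sketch: force a single SDPA backend to see which implementation actually runs.
with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=True, enable_mem_efficient=False):
    out_math = F.scaled_dot_product_attention(query, key, value)  # C++ math fallback
with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=False, enable_mem_efficient=True):
    out_mem = F.scaled_dot_product_attention(query, key, value)   # memory-efficient kernel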
My variables are query, key, and value, all of shape (B, L, E) = (1, 32768, 192).
Here is the function I'm using:
x = torch.nn.functional.scaled_dot_product_attention(query, key, value)
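Side note: as far as I understand, the default softmax scale inside scaled_dot_product_attention is 1/sqrt(E), i.e. the same factor as in my manual implementation below. A minimal sketch making that explicit (the scale keyword only exists in newer PyTorch releases):

import math
import torch.nn.functional as F

# Sketch: pass the softmax scale explicitly; this should be identical to SDPA's default
# of 1 / sqrt(query.size(-1)). Assumes query, key, value are defined as above.
x_explicit = F.scaled_dot_product_attention(query, key, value, scale=1.0 / math.sqrt(query.size(-1)))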
Here is my older implementation:
import math
attn_weight = torch.softmax(query @ key.transpose(-2, -1) / math.sqrt(query.size(-1)), dim=-1)
x2 = attn_weight @ value
I expected x to be exactly equal to x2, but there are minor differences in some elements of the output tensor.
Am I understanding the SDPA function correctly? I wouldn't mind the outputs being very slightly off, but it seems this affects training significantly.
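For what it's worth, here is a minimal sketch of how I would compare the two outputs numerically, assuming x, x2, query, key, and value are defined as above; the tolerances are just illustrative:

import math
import torch
import torch.nn.functional as F

# Sketch: quantify the difference between the SDPA output and the manual implementation.
diff = (x - x2).abs()
print("max abs diff:", diff.max().item())
print("mean abs diff:", diff.mean().item())
print("allclose (float32-ish tolerances):", torch.allclose(x, x2, atol=1e-5, rtol=1e-4))

# Repeating the comparison in float64 can show whether the gap is just floating-point
# accumulation-order noise or something larger (fp64 typically falls back to the math kernel).
q64, k64, v64 = query.double(), key.double(), value.double()
x64 = F.scaled_dot_product_attention(q64, k64, v64)
x2_64 = torch.softmax(q64 @ k64.transpose(-2, -1) / math.sqrt(q64.size(-1)), dim=-1) @ v64
print("max abs diff (float64):", (x64 - x2_64).abs().max().item())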
Thank you!