In-place operation error related to scaled_dot_product_attention

Hi, we’ve been developing a large piece of code for a couple of months, and recently we ran into an in-place operation error when trying to run it on a different device:

one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [4, 6, 3584, 64]], which is output 0 of ScaledDotProductEfficientAttentionBackward0, is at version 4; expected version 0 instead.

This is the first time we have ever seen this error, and the same piece of code runs fine on our other devices. I did some research, and it seems in-place operation errors are fairly basic and shouldn’t depend on the device, so this is quite puzzling. Any help/input would be appreciated!
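
If I understand the error correctly, it means a tensor that autograd saved for the backward pass was later modified in place, so its version counter no longer matches what autograd recorded. A minimal repro of the same class of error, purely as an illustration and not our actual code:

import torch

a = torch.randn(4, requires_grad=True)
b = torch.exp(a)      # exp() saves its output tensor for the backward pass
b += 1                # the in-place add bumps b's version counter
b.sum().backward()    # RuntimeError: ... modified by an inplace operation ... expected version 0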

If we replace the line x = F.scaled_dot_product_attention(q, k, v) with

scale_factor = 1 / math.sqrt(q.size(-1))
attn_weight = q @ k.transpose(-2, -1) * scale_factor       # raw attention scores
attn_weight = torch.softmax(attn_weight, dim=-1)
attn_weight = torch.dropout(attn_weight, 0.0, train=True)  # p=0.0, effectively a no-op
x = attn_weight @ v

then the code runs fine as well.
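
My guess for why the manual version works and why the behaviour differs across machines: scaled_dot_product_attention dispatches to different backends (flash, memory-efficient, or plain math) depending on the device, dtype, head dimension, and mask, and they don't all save the same tensors for backward. Assuming a torch 2.x build where these toggles exist (newer releases expose the same controls through torch.nn.attention.sdpa_kernel), the backend can be inspected and pinned like this:

import torch
import torch.nn.functional as F

# Whether each SDPA backend is currently enabled on this build
print(torch.backends.cuda.flash_sdp_enabled())
print(torch.backends.cuda.mem_efficient_sdp_enabled())
print(torch.backends.cuda.math_sdp_enabled())

q = k = v = torch.randn(4, 6, 3584, 64, device="cuda", requires_grad=True)

# Pin SDPA to the plain math backend, mirroring the manual computation above
with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
    x = F.scaled_dot_product_attention(q, k, v)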

Update: the error is gone after we add .clone(): x = F.scaled_dot_product_attention(q, k, v, attn_mask=mask2d_padding).clone(). Still, any explanation of what’s going on would be appreciated.
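
Our working theory, which we’d appreciate having confirmed or corrected: the efficient attention kernel saves its output tensor for the backward pass (the error message names ScaledDotProductEfficientAttentionBackward0), and something downstream in our code modifies x in place; .clone() hands that downstream code its own copy, so the saved output stays at version 0. A sketch of the idea, with a hypothetical x += 1 standing in for whatever in-place op our code actually performs:

import torch
import torch.nn.functional as F

q = torch.randn(1, 2, 16, 64, device="cuda", requires_grad=True)
k = torch.randn(1, 2, 16, 64, device="cuda", requires_grad=True)
v = torch.randn(1, 2, 16, 64, device="cuda", requires_grad=True)

x = F.scaled_dot_product_attention(q, k, v)
x += 1  # hypothetical downstream in-place op, mutates the tensor saved for backward
try:
    x.sum().backward()  # raises on backends that save the output, e.g. efficient attention
except RuntimeError as e:
    print(e)

x = F.scaled_dot_product_attention(q, k, v).clone()
x += 1               # only the clone is mutated; the saved output stays at version 0
x.sum().backward()   # no error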