Multi-head attention gives wrong attention weights when model is set to .eval()

I am using nn.MultiheadAttention in my transformer layer and I want to visualize the attention over my tokens. However, whenever I set the model to model.eval(), the multi-head attention outputs a matrix of constant values, whereas with model.train() it outputs something different that is not constant. I am not sure why this strange behaviour is showing up. I have tested it with PyTorch 1.7 and PyTorch 1.11. Looking forward to an answer.

I’m not sure I can reproduce the issue from the current description:

>>> import torch
>>> model = torch.nn.MultiheadAttention(256, 8)
>>> model = model.cuda()
>>> inp = torch.randn(128, 4, 256, device='cuda')
>>> out = model(inp, inp, inp)
>>> model = model.eval()
>>> out2 = model(inp, inp, inp)
>>> torch.allclose(out[0], out2[0])
>>> torch.allclose(out[1], out2[1])
>>> model = model.train()
>>> out3 = model(inp, inp, inp)
>>> torch.allclose(out[0], out3[0])
>>> torch.allclose(out[1], out3[1])
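One thing worth checking: if you constructed the layer with `dropout > 0`, then in train() mode dropout is applied inside the attention computation, so repeated calls give different results, while eval() disables dropout and is deterministic. A quick sketch of that effect (the `dropout=0.5` value is just for illustration):

```python
import torch

torch.manual_seed(0)

# Same layer as above, but with nonzero dropout on the attention weights.
model = torch.nn.MultiheadAttention(256, 8, dropout=0.5)
inp = torch.randn(128, 4, 256)  # (seq_len, batch, embed_dim)

# In train mode, dropout randomizes each forward pass.
model.train()
out_a = model(inp, inp, inp)
out_b = model(inp, inp, inp)
print(torch.allclose(out_a[0], out_b[0]))  # False: dropout differs per call

# In eval mode, dropout is disabled and the output is deterministic.
model.eval()
out_c = model(inp, inp, inp)
out_d = model(inp, inp, inp)
print(torch.allclose(out_c[0], out_d[0]))  # True: no randomness in eval
```

If that matches what you are seeing, the eval() output is the correct one to visualize.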

Do you have a minimal code snippet that reproduces the issue?