nn.MultiheadAttention permutation equivariance

Hi, when I feed nn.MultiheadAttention the same inputs in a permuted order, I find that the outputs are not permutation equivariant: permuting the input rows does not simply permute the output rows in the same way. I would really appreciate any comments and help!

Below is my test script:

import torch
from torch import nn
embed_dim = 128
num_heads = 4
multihead_attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first = False)
# scale up the output projection and give it a large random bias
multihead_attn.out_proj.weight.data = multihead_attn.out_proj.weight.data * 100
multihead_attn.out_proj.bias.data = torch.rand(embed_dim) * 10
# 100 tokens; token i is a 128-dim vector filled with the value i
x = torch.arange(100)[...,None].repeat(1,128).float()

# permute the order of the 100 input tokens
perturbed_idx = torch.randperm(100)
perturbed_x = x[perturbed_idx]

# sanity check: prints 0 and 0, since perturbed_x is exactly x[perturbed_idx]
print(torch.max(perturbed_x - x[perturbed_idx]), torch.min(perturbed_x - x[perturbed_idx]))

# shape (100, 1, 128) = (seq_len, batch, embed_dim), since batch_first=False
x_in = x.unsqueeze(1)
perturbed_x_in = perturbed_x.unsqueeze(1)
out = multihead_attn(x_in,x_in,x_in)[0]
out1 = multihead_attn(perturbed_x_in,perturbed_x_in,perturbed_x_in)[0]

# if the layer is permutation equivariant, out[perturbed_idx] should equal out1
print(torch.max(out[perturbed_idx] - out1), torch.min(out[perturbed_idx] - out1))
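One way to check whether the mismatch is ordinary floating-point rounding rather than a real break in equivariance (a hypothesis, not a confirmed diagnosis) is to repeat the same test in float32 and float64 and compare the error to the size of the outputs. This is a minimal sketch assuming that, without positional encodings, self-attention with query = key = value should match up to rounding; the helper `permutation_error` and the fixed seed are hypothetical additions, not part of the original script.

```python
import torch
from torch import nn

# Assumption: the fixed seed and the helper name `permutation_error` are
# illustrative additions, not part of the original script.
torch.manual_seed(0)

embed_dim, num_heads, seq_len = 128, 4, 100


def permutation_error(dtype):
    """Return (max abs error, error relative to max |out|) of the check
    out[idx] == mha(x[idx]) with all tensors in the given dtype."""
    mha = nn.MultiheadAttention(embed_dim, num_heads).to(dtype)
    with torch.no_grad():
        # same modifications to the output projection as in the script above
        mha.out_proj.weight.mul_(100)
        mha.out_proj.bias.copy_(torch.rand(embed_dim, dtype=dtype) * 10)

        x = torch.arange(seq_len, dtype=dtype)[:, None].repeat(1, embed_dim)
        idx = torch.randperm(seq_len)
        x_in = x.unsqueeze(1)        # (seq_len, batch=1, embed_dim)
        xp_in = x[idx].unsqueeze(1)  # same tokens, permuted order

        out = mha(x_in, x_in, x_in)[0]
        out_p = mha(xp_in, xp_in, xp_in)[0]

    abs_err = (out[idx] - out_p).abs().max()
    rel_err = abs_err / out.abs().max()
    return abs_err.item(), rel_err.item()


err32 = permutation_error(torch.float32)
err64 = permutation_error(torch.float64)
print("float32 (abs, rel):", err32)
print("float64 (abs, rel):", err64)
```

If the relative error shrinks by many orders of magnitude in float64, rounding is the likely explanation. Because the output weights are scaled by 100 and the inputs go up to 99, the outputs are large, so a small relative error can still appear as a visible absolute difference in the original print.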