Is Multi-Head Self-Attention in a Transformer permutation-invariant or permutation-equivariant, and how can I see it in practice?

I read that a function f is equivariant if f(P(x)) = P(f(x)), where P is a permutation.

So, to check what equivariant and permutation-invariant actually mean, I wrote the following code:

import torch
import torch.nn as nn

multihead_attn = nn.MultiheadAttention(embed_dim=32, num_heads=4, batch_first=True)

# Two sequences of 11 tokens with embedding dimension 32; x1 is x0 with its
# rows cyclically shifted by one position (i.e. a permutation of x0's rows).
x0 = torch.ones(11, 32)
x1 = torch.ones(11, 32)
for i in range(x0.size(0)):
    x0[i] *= i
    x1[i] *= (i + 1) % x0.size(0)

# Stack the two sequences into a batch of shape (2, 11, 32).
x = torch.cat((x0.unsqueeze(0), x1.unsqueeze(0)))

# With batch_first=True the attention output has shape (2, 11, 32);
# unpack it into the two per-sequence outputs of shape (11, 32).
y0, y1 = multihead_attn(x, x, x)[0]
y0 = y0.squeeze(0)
y1 = y1.squeeze(0)

Then, as a sanity check: torch.equal(x0[1], x1[0]) >>> True

but torch.equal(y0, y1) >>> False, so it doesn't seem to be permutation-invariant,
and torch.equal(y0[1], y1[0]) >>> False, so it doesn't seem to be equivariant either.

So what am I doing wrong?

The issue is that you are using torch.equal on a float tensor, which tests for exact equality. The permuted rows are summed in a different order inside the attention, so the two outputs differ by tiny floating-point rounding errors. You can use
torch.allclose(y0[1], y1[0], atol=1e-6) instead (which evaluates to True).
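For example, with the y0 and y1 from your snippet, checks along these lines should pass (just a sketch; the tolerance is simply a reasonable choice):

# Row 1 of sequence 0 and row 0 of sequence 1 are the same input token,
# so their outputs should match up to floating-point rounding (equivariance).
print(torch.allclose(y0[1], y1[0], atol=1e-6))  # expected: True
# Undoing the cyclic shift should align the full outputs row by row.
print(torch.allclose(y0, torch.roll(y1, shifts=1, dims=0), atol=1e-6))  # expected: True
# Invariance would require identical outputs without undoing the shift.
print(torch.allclose(y0, y1, atol=1e-6))  # expected: False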
I am not entirely sure what you are trying to do here, but note that
print(torch.equal(x0[0], x1[1]))
is False as well.
Hope that helps

@paulk
Thanks a lot, yes, that perfectly answers my question. It shows that multi-head attention is equivariant, and not permutation-invariant as I often read.

I do not understand why, but on my computer print(torch.equal(x0[0], x1[1])) is True.
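For anyone reading later, a more direct way to see f(P(x)) = P(f(x)) on a single sequence could look something like this (just a sketch; the random permutation, seed, and tolerance are arbitrary choices):

import torch
import torch.nn as nn

torch.manual_seed(0)
attn = nn.MultiheadAttention(embed_dim=32, num_heads=4, batch_first=True)

x = torch.randn(1, 11, 32)                  # one sequence of 11 tokens
perm = torch.randperm(11)                   # an arbitrary permutation P
x_perm = x[:, perm, :]                      # P(x)

y, _ = attn(x, x, x)                        # f(x)
y_perm, _ = attn(x_perm, x_perm, x_perm)    # f(P(x))

# Equivariance: f(P(x)) should equal P(f(x)) up to floating-point rounding.
print(torch.allclose(y_perm, y[:, perm, :], atol=1e-6))  # expected: True
# Invariance would require f(P(x)) == f(x), which should not hold here.
print(torch.allclose(y_perm, y, atol=1e-6))               # expected: False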

Glad it helped. Not sure if this is relevant then, but when I run:

import torch
import torch.nn as nn

multihead_attn = nn.MultiheadAttention(embed_dim=32, num_heads=4, batch_first=True)
x0 = torch.ones(11,32)
x1 = torch.ones(11,32)
for i in range(x0.size(0)):
    x0[i] *= i
    x1[i] *= (i+1) % x0.size(0)

x = torch.cat(
    (x0.unsqueeze(0), x1.unsqueeze(0))
    )

y0, y1 = multihead_attn(x,x,x)[0]
y0 = y0.squeeze(0)
y1 = y1.squeeze(0)

print(torch.equal(x0[0], x1[1]))

It prints False while

print(torch.equal(x0[1], x1[0]))

prints True.
But that makes sense, and your rotation works fine. (For some reason I assumed x0 and x1 had shape [2, 32] instead of [11, 32], and thus thought torch.equal(x0[0], x1[1]) should be True as well.)

TLDR:
Don’t worry about it, my bad.

You are totally right. It was a typo :smile:
Many thanks
