I read that a function f is equivariant if f(P(x)) = P(f(x)), where P is a permutation, and permutation invariant if f(P(x)) = f(x).
So to check what equivariant and permutation invariant mean in practice, I wrote the following code:
import torch
import torch.nn as nn

multihead_attn = nn.MultiheadAttention(embed_dim=32, num_heads=4, batch_first=True)

x0 = torch.ones(11, 32)
x1 = torch.ones(11, 32)
for i in range(x0.size(0)):
    # row i of x0 is the constant vector i; x1 is x0 cyclically
    # shifted by one position, i.e. x1[i] == x0[(i + 1) % 11]
    x0[i] *= i
    x1[i] *= (i + 1) % x0.size(0)

# stack the two sequences into a batch of shape [2, 11, 32]
x = torch.cat((x0.unsqueeze(0), x1.unsqueeze(0)))

# forward pass; [0] selects attn_output, shape [2, 11, 32]
y0, y1 = multihead_attn(x, x, x)[0]
y0 = y0.squeeze(0)  # no-op here, y0 already has shape [11, 32]
y1 = y1.squeeze(0)
Then, to verify the setup: torch.equal(x0[1], x1[0]) >>> True,
but torch.equal(y0, y1) >>> False, so it doesn't seem to be permutation invariant,
and torch.equal(y0[1], y1[0]) >>> False, so it doesn't seem to be equivariant either.
The issue is that you are using torch.equal on float tensors, which requires exact bitwise equality; permuting the inputs changes the order of the floating-point operations, so the outputs differ by tiny round-off errors. Use torch.allclose(y0[1], y1[0], atol=1e-6) instead (which evaluates to True).
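For reference, here is a minimal sketch of the full equivariance check with allclose, using an arbitrary random permutation instead of your cyclic shift (the seed, variable names, and tolerance are my own choices, not from your post):

import torch
import torch.nn as nn

torch.manual_seed(0)
attn = nn.MultiheadAttention(embed_dim=32, num_heads=4, batch_first=True)
attn.eval()  # make the forward pass deterministic (dropout is 0 by default anyway)

x = torch.randn(1, 11, 32)   # one batch of 11 tokens
perm = torch.randperm(11)    # a random permutation P

y = attn(x, x, x)[0]                                   # f(x)
y_perm = attn(x[:, perm], x[:, perm], x[:, perm])[0]   # f(P(x))

print(torch.allclose(y_perm, y[:, perm], atol=1e-6))   # True: f(P(x)) ≈ P(f(x))
print(torch.equal(y_perm, y[:, perm]))                 # typically False: round-off only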
I am not entirely sure what you are trying to do here, but note that print(torch.equal(x0[0], x1[1])) is False as well.
Hope that helps
@paulk
Thanks a lot, yes, that perfectly answers my question. It shows that multi-head attention is equivariant but not permutation invariant, as I have often read.
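For completeness, spelling out both checks with allclose on the tensors from my snippet above (a quick sketch; the tolerance is the one you suggested):

# y1 is (approximately) the cyclic shift of y0, but not equal to y0 itself
print(torch.allclose(y0, y1, atol=1e-6))        # False: not permutation invariant
print(torch.allclose(y0[1], y1[0], atol=1e-6))  # True: equivariant under the shift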
I do not understand why, but on my computer print(torch.equal(x0[0], x1[1])) is True.
Glad it helped. Not sure if this is relevant then, but when I run:
import torch
import torch.nn as nn

multihead_attn = nn.MultiheadAttention(embed_dim=32, num_heads=4, batch_first=True)

x0 = torch.ones(11, 32)
x1 = torch.ones(11, 32)
for i in range(x0.size(0)):
    x0[i] *= i
    x1[i] *= (i + 1) % x0.size(0)

x = torch.cat((x0.unsqueeze(0), x1.unsqueeze(0)))

y0, y1 = multihead_attn(x, x, x)[0]
y0 = y0.squeeze(0)
y1 = y1.squeeze(0)

print(torch.equal(x0[0], x1[1]))
It prints False while
print(torch.equal(x0[1], x1[0]))
prints True.
But that makes sense, and your rotation works fine. (For some reason I assumed x had shape [2, 32] instead of [11, 32], and thus thought torch.equal(x0[0], x1[1]) should be True as well.)
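To make the rotation explicit, x1 is just x0 cyclically shifted by one position, which can be checked with a one-liner (torch.roll is my suggestion here, not something from the original code):

# x1[i] == x0[(i + 1) % 11], i.e. x1 is x0 rolled by -1 along dim 0
print(torch.equal(x1, torch.roll(x0, shifts=-1, dims=0)))  # True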