Can somebody tell me how to implement stand-alone self-attention in PyTorch?
Here is my attempt. I found one implementation on GitHub, but I think it is wrong.
import torch
import torch.nn as nn
q = torch.randn(1, 1, 3, 3) # for simplicity assume one image, one channel, H x W = 3 x 3
k = torch.randn(1, 1, 3, 3)
v = torch.randn(1, 1, 3, 3)
l_q = nn.Linear(1, 20)
l_k = nn.Linear(1, 20)
l_v = nn.Linear(1, 20) # separate projection for the values
transformed_q = l_q(q.view(1, 1, 3, 3, 1))
transformed_k = l_k(k.view(1, 1, 3, 3, 1))
transformed_v = l_v(v.view(1, 1, 3, 3, 1))
transformed_k.shape # each pixel is now a 20-dimensional embedding
torch.Size([1, 1, 3, 3, 20])
t_q = transformed_q.reshape(1, 1, 9, 20)
t_k = transformed_k.reshape(1, 1, 9, 20)
t_v = transformed_v.reshape(1, 1, 9, 20)
emb = nn.Embedding(9, 20) # one embedding per position in the 3x3 grid (absolute index, not a row/column offset)
softmax = nn.Softmax(dim=-1)
attn_output_weights = softmax(t_q @ t_k.transpose(2, 3) +
                              t_q @ emb(torch.arange(9)).transpose(0, 1))
attn = attn_output_weights@t_v
How do I use relative position embeddings, that is, row and column offsets? I think what I am doing is wrong.
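
To make the row/column offset idea concrete, here is a rough sketch of what I think it could look like. It is not a verified implementation. The assumptions are mine: the names H, W, d, row_emb and col_emb are made up for illustration, every pixel attends to all 9 pixels (global attention rather than a local window), and the embedding size is split in half between the row offset and the column offset.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
H, W, d = 3, 3, 20   # image height, width, embedding size (assumed even)
N = H * W

# Stand-ins for the projected queries/keys/values, shaped (1, 1, 9, 20).
t_q = torch.randn(1, 1, N, d)
t_k = torch.randn(1, 1, N, d)
t_v = torch.randn(1, 1, N, d)

# Hypothetical tables: one for row offsets, one for column offsets, each d // 2 wide.
# Offsets run from -(H-1) to H-1, so they are shifted by H-1 to become valid indices.
row_emb = nn.Embedding(2 * H - 1, d // 2)
col_emb = nn.Embedding(2 * W - 1, d // 2)

# (row, col) coordinate of each flattened pixel, in row-major order.
rows = torch.arange(H).repeat_interleave(W)   # [0,0,0,1,1,1,2,2,2]
cols = torch.arange(W).repeat(H)              # [0,1,2,0,1,2,0,1,2]

# rel_row[i, j] = (row of key j) - (row of query i), shifted to be non-negative.
rel_row = rows[None, :] - rows[:, None] + (H - 1)   # (N, N)
rel_col = cols[None, :] - cols[:, None] + (W - 1)   # (N, N)

# Concatenate row-offset and column-offset embeddings into (N, N, d).
rel = torch.cat([row_emb(rel_row), col_emb(rel_col)], dim=-1)

# Content logits (q . k) plus position logits (q . r), both (1, 1, N, N).
content_logits = t_q @ t_k.transpose(2, 3)
position_logits = torch.einsum('bhqd,qkd->bhqk', t_q, rel)
attn_weights = torch.softmax(content_logits + position_logits, dim=-1)
attn = attn_weights @ t_v   # (1, 1, N, d)
```

The difference from my attempt above is that the embedding is looked up per (query, key) pair from their row and column offsets, instead of once per absolute pixel index, so two pairs with the same offset share the same position term.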