Maybe this is wrong, but suppose, for example, that we have an image:
input = torch.randn((1, 1, 28, 28))  # one random single-channel 28x28 "image", laid out (N, C, H, W)
We pass it through a convolutional layer:
# 3x3 kernel with padding=1 keeps the 28x28 spatial size; 1 input channel -> 10 feature channels.
cnv = nn.Conv2d(in_channels=1, out_channels=10, kernel_size=3, padding=1)
cnv(input).size()
torch.Size([1, 10, 28, 28])
q = torch.randn(1, 1, 10)  # target pixel we want: query of shape (L=1, N=1, E=10)
# The conv output is (N, C, H, W) = (1, 10, 28, 28). nn.MultiheadAttention
# (default batch_first=False) expects keys/values as (S, N, E) = (784, 1, 10),
# i.e. one token per pixel holding that pixel's 10 channel values. A bare
# reshape(28*28, 1, 10) does not move the channel axis: it would pack 10
# consecutive values of a single channel into each "token". Flatten the
# spatial dims and move them to the front instead.
k = cnv(input).flatten(2).permute(2, 0, 1)  # input pixels
v = cnv(input).flatten(2).permute(2, 0, 1)  # input pixels
attn = nn.MultiheadAttention(10, 2)
# Call attention once and inspect both outputs instead of running it twice.
attn_out, attn_weights = attn(q, k, v)
attn_out.shape, attn_weights.shape
(torch.Size([1, 1, 10]), torch.Size([1, 1, 784]))
The first output is the representation of the pixel we want; the second (the attention weights) shows how much that representation depends on each of the 28*28 input pixels.
lin1 = nn.Linear(10, 1)  # maps the 10-dim attended representation to a single value
# Bind the prediction to `p`: the following steps (lin2(p), torch.cat(..., lin2(p)))
# refer to it, and without this assignment they raise NameError.
p = lin1(attn(q, k, v)[0])  # shape (1, 1, 1)
p
tensor([[[0.2838]]], grad_fn=)
This is the value of the next pixel; let us call it `p`.
The next time we run this, we also want to use this value, so we project it back up to the 10-dimensional embedding size:
# Embed the scalar pixel value back into the 10-dim attention space.
lin2 = nn.Linear(in_features=1, out_features=10)
lin2(p)
This gives a tensor of shape [1, 1, 10], so our next k and v would be:
# Append the embedded prediction as one extra token along the sequence axis
# (dim -3 of the 3-D (S, N, E) key/value tensors).
k = torch.cat([k, lin2(p)], dim=-3)
v = torch.cat([v, lin2(p)], dim=-3)
Here is one implementation:
class Net(nn.Module):
    """Attention-based next-pixel predictor that grows its own context.

    On a call with a 4-D image the image is encoded by a 3x3 convolution
    (padding 1, so H and W are kept) into one ``embed_dim``-sized token per
    pixel. On later calls the caller passes back the token sequence returned
    by the previous call. Each call attends from ``nxt_pxl`` over that
    sequence, predicts a scalar value, embeds it back to ``embed_dim`` and
    appends it to the sequence.

    Tensors follow ``nn.MultiheadAttention``'s default sequence-first layout
    (``batch_first`` is not set): (sequence, batch, embed_dim).
    """

    def __init__(self, embed_dim=10, num_heads=2, in_channels=1):
        super().__init__()
        self.attn = nn.MultiheadAttention(embed_dim, num_heads)
        self.conv = nn.Conv2d(in_channels, embed_dim, 3, padding=1)
        self.lin1 = nn.Linear(embed_dim, 1)  # attended token -> scalar pixel value
        self.lin2 = nn.Linear(1, embed_dim)  # scalar pixel value -> token embedding

    @staticmethod
    def _image_to_tokens(feats):
        """Turn an (N, C, H, W) feature map into an (H*W, N, C) token sequence.

        A plain ``reshape(H*W, N, C)`` does not move the channel axis; it would
        regroup consecutive values of a single channel into fake tokens.
        Flattening the spatial dims and permuting keeps each token equal to the
        C channel values of one pixel (token index ``h * W + w``).
        """
        return feats.flatten(2).permute(2, 0, 1)

    def forward(self, input, nxt_pxl):
        """Predict the next pixel and extend the token sequence.

        Args:
            input: either an image batch (N, in_channels, H, W), which is
                encoded by ``self.conv``, or a token sequence
                (S, N, embed_dim) returned by a previous call.
            nxt_pxl: attention query of shape (L, N, embed_dim).

        Returns:
            A tuple ``(pred, next_query, next_seq)``:
            ``pred`` (L, N, 1) is the predicted pixel value (part of the graph,
            use it for the loss); ``next_query`` (L, N, embed_dim) is ``pred``
            re-embedded; ``next_seq`` (S + L, N, embed_dim) is the input
            sequence with ``next_query`` appended. The last two are detached.
        """
        if input.ndim == 4:
            seq = self._image_to_tokens(self.conv(input))
        else:
            seq = input
        attn_out, _ = self.attn(nxt_pxl, seq, seq)
        pred = self.lin1(attn_out)
        next_query = self.lin2(pred)  # next pixel
        next_seq = torch.cat((seq, next_query), dim=-3)  # next input
        # Detach what gets fed back in, so each backward() only traverses the
        # current step's graph instead of the (already freed) earlier ones.
        return pred, next_query.detach(), next_seq.detach()
# Toy training loop: the image is encoded once; afterwards the growing token
# sequence and the re-embedded prediction returned by the model are fed back in.
model = Net()
loss_fn = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
model_input = torch.randn(1, 1, 28, 28)  # avoid shadowing the builtin `input`
# pred has shape (1, 1, 1); give the target the same shape so MSELoss does not
# silently broadcast (and warn about a size mismatch).
target = torch.ones(1, 1, 1)
nxt_pxl = torch.randn(1, 1, 10)  # first time we specify random next pixel
for _ in range(10):
    optimizer.zero_grad()
    print(model_input.shape)
    pred, nxt_pxl, model_input = model(model_input, nxt_pxl)
    loss = loss_fn(pred, target)
    print(pred.shape)
    print(loss)
    loss.backward()
    optimizer.step()