Autoregressive sampling with Multi-head attention

I am trying to make sure I understand the implementation of torch.nn.MultiheadAttention(), as I want to use it for autoregressive sampling in a “decoder-only” image transformer.

In such a case, one tries to predict the next pixel by attending to all previous pixels. Say my pixels are just binary (0 or 1).

In the Multihead attention forward function, the input is supposed to be:
query: [target length, batch size, embed dim]
key: [sequence length, batch size, embed dim]
value: [sequence length, batch size, embed dim]

What is “target length” referring to in this context? If I am just trying to predict the next pixel at each pass, is this just length 1?

Also, is there any way to do batch_first with this module?

Thanks for your help!

Maybe this is wrong, but as an example, suppose we have an image:

input = torch.randn(1, 1, 28, 28)

We pass it through a convolutional layer:

cnv = nn.Conv2d(1, 10, 3, padding=1)
cnv(input).shape

torch.Size([1, 10, 28, 28])

q = torch.randn(1, 1, 10) # target pixel we want
# (1, 10, 28, 28) -> (28*28, 1, 10); a plain reshape would mix channels with pixel positions
feats = cnv(input).flatten(2).permute(2, 0, 1)
k = feats # input pixels
v = feats # input pixels
attn = nn.MultiheadAttention(10, 2)
attn(q, k, v)[0].shape, attn(q, k, v)[1].shape

(torch.Size([1, 1, 10]), torch.Size([1, 1, 784]))

The first is the representation of the pixel we want. The second holds the attention weights (averaged over the heads), i.e. how much this pixel representation depends on each of the 28*28 input pixels.
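To make those shapes concrete, and to touch on the two questions above, here is a minimal sketch. It assumes PyTorch 1.9 or later, where nn.MultiheadAttention accepts batch_first=True; the batch size of 4 and the random tensors are only placeholders. With a single query the target length is 1, while the sequence length is the number of pixels being attended to.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

embed_dim, num_heads = 10, 2
batch, seq_len = 4, 28 * 28  # placeholder sizes for illustration

# batch_first=True makes every input and output (batch, length, embed_dim)
attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)

q = torch.randn(batch, 1, embed_dim)         # target length 1: one query per image
kv = torch.randn(batch, seq_len, embed_dim)  # sequence length 784: pixels attended to

out, weights = attn(q, kv, kv)
print(out.shape)      # (batch, target length, embed_dim)
print(weights.shape)  # (batch, target length, sequence length)
```

Each row of the weights is a softmax over the sequence dimension, so it sums to 1.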

lin1 = nn.Linear(10, 1)
p = lin1(attn(q, k, v)[0])
p

tensor([[[0.2838]]], grad_fn=<...>)

This is the value of the next pixel; let us call it p.

The next time we run this, we also want to use this value, so

lin2 = nn.Linear(1, 10)
lin2(p)

would give a tensor of shape [1, 1, 10]. Our next k and v would be

k = torch.cat((k, lin2(p)), dim=-3)
v = torch.cat((v, lin2(p)), dim=-3)
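Since the pixels in the question are binary, the grow-by-one-step idea above can be sketched as a sampling loop. This is only an illustration under assumptions that are not in the original post: the all-zero start token, treating the output of lin1 as a logit, and sampling with a sigmoid followed by torch.bernoulli. The layers are untrained, so the samples are meaningless; only the mechanics matter here.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

embed_dim, num_heads, num_pixels = 10, 2, 8

attn = nn.MultiheadAttention(embed_dim, num_heads)  # default layout: (length, batch, embed_dim)
lin1 = nn.Linear(embed_dim, 1)  # attention output -> logit of the next pixel (assumption)
lin2 = nn.Linear(1, embed_dim)  # sampled pixel value -> embedding

context = torch.zeros(1, 1, embed_dim)  # hypothetical start token
pixels = []

with torch.no_grad():
    for _ in range(num_pixels):
        q = context[-1:]                    # target length 1: query from the latest position
        out, _ = attn(q, context, context)  # attend to everything generated so far
        prob = torch.sigmoid(lin1(out))     # probability that the next pixel is 1
        pixel = torch.bernoulli(prob)       # sample a binary pixel, shape (1, 1, 1)
        pixels.append(int(pixel.item()))
        # Append the new pixel's embedding so the next step can attend to it
        context = torch.cat((context, lin2(pixel)), dim=0)

print(pixels)
print(context.shape)  # start token plus num_pixels generated positions
```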

Here is one implementation:

import torch
import torch.nn as nn

class Net(nn.Module):
  def __init__(self):
    super().__init__()
    self.attn = nn.MultiheadAttention(10, 2)
    self.conv = nn.Conv2d(1, 10, 3, padding=1)
    self.lin1 = nn.Linear(10, 1)  # attention output -> pixel value
    self.lin2 = nn.Linear(1, 10)  # pixel value -> embedding for the next step

  def forward(self, input, nxt_pxl):
    if input.ndim == 4:
      # First call: raw image (batch, 1, H, W) -> features (batch, 10, H, W)
      conv_out = self.conv(input)
      # (batch, 10, H, W) -> (H*W, batch, 10); a plain reshape would mix channels and positions
      conv_out = conv_out.flatten(2).permute(2, 0, 1)
    else:
      # Later calls: input is already the (length, batch, 10) sequence from the previous step
      conv_out = input
    out, _ = self.attn(nxt_pxl, conv_out, conv_out)  # out: (1, batch, 10)
    out = self.lin1(out)  # predicted pixel value, (1, batch, 1)
    out1 = self.lin2(out)  # next pixel
    out2 = torch.cat((conv_out, out1), dim=0)  # next input, one position longer
    # Detach so each step's backward pass stays inside that step's graph
    return out, out1.detach(), out2.detach()

model = Net()
loss_fn = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
input = torch.randn(1, 1, 28, 28)
target = torch.ones(1, 1, 1)  # same shape as pred, so MSELoss does not broadcast
nxt_pxl = torch.randn(1, 1, 10)  # first time we specify a random next pixel
for i in range(10):
  optimizer.zero_grad()
  print(input.shape)
  pred, nxt_pxl, input = model(input, nxt_pxl)
  loss = loss_fn(pred, target)
  print(pred.shape)
  print(loss)
  loss.backward()
  optimizer.step()
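The loop above grows the sequence one position per step, so the target length is always 1. For training, a common alternative is to pass the whole sequence as query, key and value together with a causal attn_mask, so that every position is handled in a single pass. The sketch below is not from the posts above and assumes PyTorch 1.9 or later for batch_first=True; the sizes are arbitrary.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

embed_dim, num_heads = 10, 2
batch, seq_len = 2, 16  # arbitrary sizes for illustration

attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
x = torch.randn(batch, seq_len, embed_dim)  # stand-in for embedded pixels

# Boolean mask of shape (target length, sequence length):
# True marks positions a query is NOT allowed to attend to (here: all later positions).
causal_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)

out, weights = attn(x, x, x, attn_mask=causal_mask)
print(out.shape)      # (batch, seq_len, embed_dim): target length equals sequence length
print(weights.shape)  # (batch, seq_len, seq_len), zero above the diagonal
```

Here position i only sees positions 0..i, so in an actual model the inputs would usually be shifted by one so that the output at position i is trained to predict pixel i+1.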