I use the following code in my program. Can I avoid the loop over the batch? I have tried transposing/reshaping some dimensions of q, k, and v, but the output never matched the looped version.
import torch
import torch.nn as nn

q = torch.rand(32, 256, 768)
k = torch.rand(64, 256, 768)
v = torch.rand(64, 256, 768)

# batch_first=False: inputs are (seq_len, batch, embed_dim)
cross_attn_0_to_1 = nn.MultiheadAttention(768, 8, dropout=0.0, batch_first=False)

res = []
for i in range(32):
    # q[i] attends over the pair k[2i:2i+2] / v[2i:2i+2]:
    # query (1, 256, 768), key/value (2, 256, 768)
    tmp = cross_attn_0_to_1(q[i].unsqueeze(0), k[i * 2:(i + 1) * 2], v[i * 2:(i + 1) * 2])[0].squeeze(0)
    res.append(tmp)
res = torch.stack(res)  # (32, 256, 768)
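
For reference, this is the kind of loop-free version I am after (a sketch, assuming the pairing of q[i] with k[2*i:2*i+2] from the loop above): since batch_first=False puts the batch in dim 1, the loop index i should be foldable into the batch dimension.

# Fold the loop index into the batch dimension:
# query becomes (L=1, N=32*256, E), key/value become (S=2, N=32*256, E).
q_flat = q.reshape(1, 32 * 256, 768)
k_flat = k.reshape(32, 2, 256, 768).permute(1, 0, 2, 3).reshape(2, 32 * 256, 768)
v_flat = v.reshape(32, 2, 256, 768).permute(1, 0, 2, 3).reshape(2, 32 * 256, 768)

out = cross_attn_0_to_1(q_flat, k_flat, v_flat)[0]  # (1, 32*256, 768)
res_vec = out.squeeze(0).reshape(32, 256, 768)

If the mapping is right, torch.allclose(res, res_vec, atol=1e-6) should hold, because with dropout=0.0 attention treats every batch element independently.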