How can I parallelize this module?

I use the following code in my program. Can I avoid the Python loop over the batch? I have tried transposing some dimensions of q, k, and v, but the output never matched the result of the loop.

import torch
import torch.nn as nn

q = torch.rand(32, 256, 768)   # 32 query groups
k = torch.rand(64, 256, 768)   # 2 key rows per query group (64 = 32 * 2)
v = torch.rand(64, 256, 768)   # 2 value rows per query group

# batch_first=False: dim 0 is the sequence dimension, dim 1 is the batch of 256
cross_attn_0_to_1 = nn.MultiheadAttention(768, 8, dropout=0.0, batch_first=False)

res = []
for i in range(32):
    # query i (seq_len 1) attends to its own pair of key/value rows (seq_len 2)
    tmp = cross_attn_0_to_1(q[i].unsqueeze(0), k[i * 2:(i + 1) * 2], v[i * 2:(i + 1) * 2])[0].squeeze(0)
    res.append(tmp)
res = torch.stack(res)  # (32, 256, 768)

Maybe like this:

import torch
import torch.nn as nn

q = torch.rand(32, 256, 768)
k = torch.rand(64, 256, 768)
v = torch.rand(64, 256, 768)

# batch_first=False
cross_attn_0_to_1 = nn.MultiheadAttention(768, 8, dropout=0.0, batch_first=False)

# add a leading group dimension: 1 query row and 2 key/value rows per group
q = q.view(32, 1, 256, 768)
k = k.view(32, 2, 256, 768)
v = v.view(32, 2, 256, 768)

# vmap applies the module once per group, in parallel, instead of in a Python loop
res = torch.vmap(cross_attn_0_to_1)(q, k, v)[0].squeeze(1)  # (32, 256, 768)
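If you want to sanity-check the vectorized version against the original loop, here is a minimal sketch. It assumes your PyTorch version's torch.vmap supports all the ops used inside nn.MultiheadAttention (which appears to be the case, since it worked for you); it uses the same module instance for both paths and keeps dropout at 0.0 so the comparison is deterministic, and the allclose tolerance may need loosening depending on the backend.

import torch
import torch.nn as nn

torch.manual_seed(0)
q = torch.rand(32, 256, 768)
k = torch.rand(64, 256, 768)
v = torch.rand(64, 256, 768)

attn = nn.MultiheadAttention(768, 8, dropout=0.0, batch_first=False)
attn.eval()  # keep the module deterministic

# reference: explicit loop over the 32 groups
loop_out = torch.stack([
    attn(q[i].unsqueeze(0), k[i * 2:(i + 1) * 2], v[i * 2:(i + 1) * 2])[0].squeeze(0)
    for i in range(32)
])

# vectorized: one vmap call over the group dimension
vmap_out = torch.vmap(attn)(
    q.view(32, 1, 256, 768),
    k.view(32, 2, 256, 768),
    v.view(32, 2, 256, 768),
)[0].squeeze(1)

print(loop_out.shape, vmap_out.shape)                 # torch.Size([32, 256, 768]) for both
print(torch.allclose(loop_out, vmap_out, atol=1e-5))  # tolerance for minor kernel differences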

Wow, I just tried this, it worked! Thanks a lot for your help!
