How to do matmul with [batch, nodes, in] and [nodes, in, out], and get [batch, nodes, out]

I would like to perform a per-node matmul (each node has its own weight matrix) between

A in shape [batch, num_nodes, in_channels],

and

W in shape [num_nodes, in_channels, out_channels].

The shape of the result should be [batch, num_nodes, out_channels].
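Written element-wise, the requested operation is an independent matrix product for each node n, with the same weight matrix W_n shared across the whole batch. Here Y is just a label for the result:

$$
Y_{b,n,o} = \sum_{i=1}^{\text{in\_channels}} A_{b,n,i}\, W_{n,i,o}
$$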

The only solution I have found uses a Python for loop over the nodes, which is of course very slow:

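import torch

def per_node_matmul(A, W):  # wrapper name is only for illustration
    assert A.shape[1] == W.shape[0]
    result = []
    for i in range(A.shape[1]):
        # [batch, in] @ [in, out] -> [batch, out] for node i
        result.append(A[:, i, :] @ W[i, :, :])
    return torch.stack(result, dim=1)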

Can anyone give me a hand? Thanks in advance!

torch.einsum('bni,nio->bno', A, W)
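A minimal sketch that checks the einsum one-liner against the original loop on random data. The helper names (`loop_matmul`, `einsum_matmul`, `broadcast_matmul`) and the tensor sizes are hypothetical. The third variant is an alternative not mentioned in the thread: it relies on `torch.matmul` broadcasting, treating each node's row of A as a 1×in matrix multiplied by that node's in×out weight.

```python
import torch

def loop_matmul(A, W):
    # Reference: one [batch, in] @ [in, out] product per node, stacked on dim 1
    return torch.stack([A[:, i, :] @ W[i] for i in range(A.shape[1])], dim=1)

def einsum_matmul(A, W):
    # b=batch, n=nodes, i=in_channels (summed over), o=out_channels
    return torch.einsum('bni,nio->bno', A, W)

def broadcast_matmul(A, W):
    # [batch, nodes, 1, in] @ [nodes, in, out] broadcasts to [batch, nodes, 1, out]
    return (A.unsqueeze(-2) @ W).squeeze(-2)

torch.manual_seed(0)
batch, nodes, c_in, c_out = 4, 5, 3, 2  # hypothetical sizes
A = torch.randn(batch, nodes, c_in)
W = torch.randn(nodes, c_in, c_out)

ref = loop_matmul(A, W)
out_einsum = einsum_matmul(A, W)
out_bcast = broadcast_matmul(A, W)
print(out_einsum.shape)  # torch.Size([4, 5, 2])
print(torch.allclose(ref, out_einsum, atol=1e-5), torch.allclose(ref, out_bcast, atol=1e-5))
```

Both vectorized forms avoid the Python-level loop over nodes; which one is faster may depend on the shapes and backend, so it is worth timing on real data.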


Thanks for your timely reply! May I ask: can PyTorch's autograd handle this fancy operation well?


Yes :slight_smile:
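A quick way to convince yourself is to compare gradients. The sketch below, with hypothetical sizes, backpropagates `.sum()` through both the einsum version and the original per-node loop and checks that A and W receive matching gradients.

```python
import torch

torch.manual_seed(0)
A = torch.randn(4, 5, 3, requires_grad=True)
W = torch.randn(5, 3, 2, requires_grad=True)

# Gradients through the einsum formulation
torch.einsum('bni,nio->bno', A, W).sum().backward()
grad_A_einsum, grad_W_einsum = A.grad.clone(), W.grad.clone()

# Reset, then compute gradients through the original per-node loop
A.grad = None
W.grad = None
loop_out = torch.stack([A[:, i, :] @ W[i] for i in range(A.shape[1])], dim=1)
loop_out.sum().backward()

print(torch.allclose(grad_A_einsum, A.grad, atol=1e-5))  # True
print(torch.allclose(grad_W_einsum, W.grad, atol=1e-5))  # True
```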

Great! Thank you sooooo much :smile: