I would like to perform a per-node matmul, where each node has its own weight matrix, between
A in shape [batch, num_nodes, in_channels],
and
W in shape [num_nodes, in_channels, out_channels].
The shape of the result should be [batch, num_nodes, out_channels].
The only solution I have found uses a for loop over the nodes, which is of course slow:
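import torch

def per_node_matmul(A, W):  # placeholder name; the original snippet only showed the function body
    assert A.shape[1] == W.shape[0]  # both must agree on num_nodes
    result = []
    for i in range(A.shape[1]):
        # [batch, in_channels] @ [in_channels, out_channels] -> [batch, out_channels]
        result.append(A[:, i, :] @ W[i, :, :])
    return torch.stack(result, dim=1)  # [batch, num_nodes, out_channels]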
Can anyone give me a hand? Thanks in advance!
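For reference, here is a minimal sketch of how the loop could probably be replaced by a single vectorized call, assuming A and W have exactly the shapes given above. The function names `per_node_matmul_einsum` and `per_node_matmul_bmm` are made up for illustration; `torch.einsum` and `torch.bmm` are the only library calls involved. The einsum version contracts over in_channels while keeping the node index shared between A and W; the bmm version treats num_nodes as the batch dimension of a batched matmul.

```python
import torch

def per_node_matmul_einsum(A, W):
    # Hypothetical helper name, for illustration only.
    # A: [batch, num_nodes, in_channels]        -> indices b, n, i
    # W: [num_nodes, in_channels, out_channels] -> indices n, i, o
    # Sum over i, keep n aligned between A and W.
    assert A.shape[1] == W.shape[0]
    return torch.einsum("bni,nio->bno", A, W)

def per_node_matmul_bmm(A, W):
    # Hypothetical helper name, for illustration only.
    # Move num_nodes to the front so bmm runs one matmul per node:
    # [num_nodes, batch, in] @ [num_nodes, in, out] -> [num_nodes, batch, out]
    assert A.shape[1] == W.shape[0]
    return torch.bmm(A.transpose(0, 1), W).transpose(0, 1)

if __name__ == "__main__":
    A = torch.randn(8, 5, 3)   # batch=8, num_nodes=5, in_channels=3
    W = torch.randn(5, 3, 2)   # num_nodes=5, in_channels=3, out_channels=2
    print(per_node_matmul_einsum(A, W).shape)
    print(per_node_matmul_bmm(A, W).shape)
```

Both variants should give the same values as the loop up to floating-point rounding. Which one is faster likely depends on the shapes and the device, so it is worth timing them on the real data.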