Pytorch bmm on batches with channels

Suppose you have a tensor a of shape [n, 3, h, w] and another tensor b of shape [n, h, w],
and you want to do this:

torch.stack((torch.bmm(a[:,0,:,:],b), torch.bmm(a[:,1,:,:],b), torch.bmm(a[:,2,:,:],b)), dim=1)     

Is there a better way of doing this, i.e. applying torch.bmm() to tensors where the batches have channels, but where each channel needs to be multiplied (matrix multiplication) by the same matrix?

You could add an additional dimension at dim 1 in b and let broadcasting do the rest:

a = torch.randn(10, 3, 24, 24)
b = torch.randn(10, 24, 24)

c = torch.stack((torch.bmm(a[:,0,:,:],b), torch.bmm(a[:,1,:,:],b), torch.bmm(a[:,2,:,:],b)), dim=1)     

d = torch.matmul(a, b.unsqueeze(1))
(c == d).all()
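To make the broadcasting step explicit: b.unsqueeze(1) turns b into shape [n, 1, h, w], and matmul expands the size-1 channel dimension to match a's 3 channels, so every channel is multiplied by the same per-batch matrix. As an aside (not from the thread above, just an alternative sketch), torch.einsum expresses the same contraction:

```python
import torch

a = torch.randn(10, 3, 24, 24)   # [n, c, h, w]
b = torch.randn(10, 24, 24)      # [n, h, w]

# b.unsqueeze(1) -> [n, 1, h, w]; matmul broadcasts the size-1
# channel dim against a's 3 channels, so each channel of a is
# matrix-multiplied by the same per-batch matrix b[n].
d = torch.matmul(a, b.unsqueeze(1))
print(d.shape)  # torch.Size([10, 3, 24, 24])

# The same contraction written out with einsum: sum over w of
# a[n, c, h, w] * b[n, w, k].
e = torch.einsum('nchw,nwk->nchk', a, b)
print(torch.allclose(d, e))
```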

Thanks, I was never really good at understanding, much less properly using, broadcasting.