PyTorch bmm on batches with channels


(Rajarshi Banerjee) #1

Suppose you have a tensor a of shape [n, 3, h, w] and another tensor b of shape [n, h, w],
and you want to do this:

torch.stack((torch.bmm(a[:,0,:,:],b), torch.bmm(a[:,1,:,:],b), torch.bmm(a[:,2,:,:],b)), dim=1)     

Is there a better way of doing this, i.e. applying torch.bmm() to tensors where the batches have channels, and each channel needs to be matrix-multiplied with the same matrix?


#2

You could add an additional dimension at dim1 in b and let broadcasting do the rest:

a = torch.randn(10, 3, 24, 24)
b = torch.randn(10, 24, 24)

c = torch.stack((torch.bmm(a[:,0,:,:],b), torch.bmm(a[:,1,:,:],b), torch.bmm(a[:,2,:,:],b)), dim=1)     

d = torch.matmul(a, b.unsqueeze(1))
print((c == d).all())
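To make the broadcasting step explicit, here is a minimal sketch (using the same shapes as above) of how the singleton dimension inserted by unsqueeze lets torch.matmul apply the same matrix from b to every channel of a:

```python
import torch

a = torch.randn(10, 3, 24, 24)   # [n, channels, h, w]
b = torch.randn(10, 24, 24)      # [n, h, w]

# b.unsqueeze(1) has shape [10, 1, 24, 24].
# torch.matmul treats the last two dims as matrices and broadcasts
# the leading dims, so the size-1 channel dim of b is expanded to 3:
# [10, 3, 24, 24] @ [10, 1, 24, 24] -> [10, 3, 24, 24].
d = torch.matmul(a, b.unsqueeze(1))
print(d.shape)  # torch.Size([10, 3, 24, 24])

# Each channel was multiplied by the same per-batch matrix:
print(torch.allclose(d[:, 0], torch.bmm(a[:, 0], b)))
```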

(Rajarshi Banerjee) #3

Thanks! I was never really good at understanding broadcasting, much less using it properly.