Suppose you have a tensor `a` of shape `[n, 3, h, w]` and another tensor `b` of shape `[n, h, w]`.

And you want to do this:

```
torch.stack((torch.bmm(a[:,0,:,:],b), torch.bmm(a[:,1,:,:],b), torch.bmm(a[:,2,:,:],b)), dim=1)
```

Is there a better way to do this? In other words, how can `torch.bmm()` be applied to a batch whose elements have several channels, when every channel has to be matrix-multiplied by the same matrix (the matrix in `b` for that batch element)?
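A minimal sketch of a few vectorized alternatives follows. It assumes `h == w`, because `torch.bmm(a[:,c,:,:], b)` multiplies an `(n, h, w)` batch by an `(n, h, w)` batch and is only valid when the inner sizes match. The sizes are made up for illustration. The idea is that `torch.matmul` broadcasts over leading dimensions, so giving `b` a singleton channel dimension reuses the same matrix for every channel. Alternatively, the channels can be folded into the row dimension so a single `torch.bmm` call does all of them.

```python
import torch

torch.manual_seed(0)
n, h, w = 4, 5, 5  # hypothetical sizes; w == h so that bmm(a[:, c], b) is valid
a = torch.randn(n, 3, h, w)
b = torch.randn(n, h, w)

# Reference: the per-channel bmm from the question, for any number of channels
ref = torch.stack([torch.bmm(a[:, c], b) for c in range(a.shape[1])], dim=1)

# 1) Broadcasting: b.unsqueeze(1) has shape [n, 1, h, w]; torch.matmul broadcasts
#    the size-1 channel dim against a's channels, so every channel uses b[i]
out_matmul = torch.matmul(a, b.unsqueeze(1))

# 2) einsum: contract a's last dim (k) with b's second-to-last dim (k)
out_einsum = torch.einsum("nchk,nkw->nchw", a, b)

# 3) Single bmm: stack the channels' rows into one (n, 3*h, w) matrix per sample,
#    multiply once, then split the rows back into channels
out_bmm = torch.bmm(a.reshape(n, 3 * h, w), b).view(n, 3, h, -1)

print(out_matmul.shape)  # torch.Size([4, 3, 5, 5])
print(torch.allclose(ref, out_matmul, atol=1e-5),
      torch.allclose(ref, out_einsum, atol=1e-5),
      torch.allclose(ref, out_bmm, atol=1e-5))
```

The reshape in option 3 works because row `c*h + i` of the stacked matrix is row `i` of channel `c`, and matrix multiplication treats each row independently; `torch.matmul` with `unsqueeze(1)` is usually the most readable of the three.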