Subtract constant value from tensor

The tensor a has size (1, 3, 4, 4) and the tensor b has size (3).

I want to subtract a constant value from each channel. In the code below, I want to subtract 1 from the 1st channel, 2 from the 2nd channel, and 3 from the 3rd channel.

import torch

a=torch.tensor([[[[ 1.,  2.,  3.,  4.],
          [ 5.,  6.,  7.,  8.],
          [ 9., 10., 11., 12.],
          [13., 14., 15., 16.]],

         [[17., 18., 19., 20.],
          [21., 22., 23., 24.],
          [25., 26., 27., 28.],
          [29., 30., 31., 32.]],

         [[33., 34., 35., 36.],
          [37., 38., 39., 40.],
          [41., 42., 43., 44.],
          [45., 46., 47., 48.]]]])
b=torch.tensor([1., 2., 3.])
print("a shape is ",a.shape)
print("b shape is ",b.shape)
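For reference, the desired result can be spelled out with an explicit loop over the channel dimension (dim 1). This is a minimal sketch, assuming a has shape (N, C, H, W) and b holds one value per channel; the helper name subtract_per_channel_loop is hypothetical, and the input is rebuilt with torch.arange only for brevity.

```python
import torch

def subtract_per_channel_loop(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: subtract b[c] from every element of channel c.

    Assumes a has shape (N, C, H, W) and b has shape (C,).
    """
    out = a.clone()  # leave the input tensor untouched
    for c in range(a.shape[1]):
        out[:, c] -= b[c]  # the same constant for every element of channel c
    return out

a = torch.arange(1., 49.).view(1, 3, 4, 4)  # same values as in the question
b = torch.tensor([1., 2., 3.])
res = subtract_per_channel_loop(a, b)
print(res[0, :, 0, 0])  # first element of each channel: 1-1, 17-2, 33-3
```

A Python loop like this is fine for checking results, but a single vectorized subtraction is usually preferable.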

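This should do the trick: reshape b to (1, 3, 1, 1) so that each value lines up with one channel of a, then expand it to the shape of a before subtracting: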

b = b.view(1, *b.size(), 1, 1).expand_as(a)  # (3,) -> (1, 3, 1, 1) -> (1, 3, 4, 4)
result = a - b  # subtracts 1, 2, 3 from channels 0, 1, 2
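The expand_as call is optional: once b has shape (1, 3, 1, 1), PyTorch broadcasting stretches the size-1 dimensions to match a during the subtraction. A minimal sketch comparing both forms on the question's values (rebuilt with torch.arange for brevity):

```python
import torch

a = torch.arange(1., 49.).view(1, 3, 4, 4)  # same values as in the question
b = torch.tensor([1., 2., 3.])

# Form from the answer: reshape to (1, 3, 1, 1), expand to a's shape, subtract.
result_expand = a - b.view(1, *b.size(), 1, 1).expand_as(a)

# Broadcasting form: the size-1 dims of (1, 3, 1, 1) are matched against
# (1, 3, 4, 4) automatically, so no explicit expand_as is needed.
result_broadcast = a - b.view(1, -1, 1, 1)

print(torch.equal(result_expand, result_broadcast))
print(result_broadcast[0, :, 0, 0])
```

Passing -1 to view lets PyTorch infer the channel count, so the same line works for any 1-D b whose length equals a.shape[1], and for any batch size N.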