`tensor_a` has size (1, 3, 4, 4) and `tensor_b` has size (3).

I want to subtract a constant value from each channel. In the code below, I want to subtract 1 from the 1st channel, 2 from the 2nd channel, and 3 from the 3rd channel.

```python
import torch

# tensor_a: shape (1, 3, 4, 4), i.e. (batch, channels, height, width)
a = torch.tensor([[[[ 1.,  2.,  3.,  4.],
                    [ 5.,  6.,  7.,  8.],
                    [ 9., 10., 11., 12.],
                    [13., 14., 15., 16.]],
                   [[17., 18., 19., 20.],
                    [21., 22., 23., 24.],
                    [25., 26., 27., 28.],
                    [29., 30., 31., 32.]],
                   [[33., 34., 35., 36.],
                    [37., 38., 39., 40.],
                    [41., 42., 43., 44.],
                    [45., 46., 47., 48.]]]])
# tensor_b: one constant per channel, shape (3,)
b = torch.tensor([1., 2., 3.])
print("a shape is", a.shape)
print("b shape is", b.shape)
```
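One common way to do this is a sketch relying on PyTorch broadcasting, assuming `b` always has exactly one value per channel: reshape `b` from `(3,)` to `(1, 3, 1, 1)` so its length lines up with the channel dimension of `a`, then subtract. The snippet rebuilds `a` with `torch.arange` (same values 1 to 48) only so it runs on its own; `result` is a name chosen for illustration.

```python
import torch

# Same values as `a` above: 1..48 laid out as (batch=1, channels=3, 4, 4)
a = torch.arange(1., 49.).reshape(1, 3, 4, 4)
b = torch.tensor([1., 2., 3.])

# View b as (1, 3, 1, 1) so it aligns with the channel dimension of a;
# broadcasting then repeats each channel's constant over the 4x4 map.
result = a - b.view(1, 3, 1, 1)

print(result.shape)  # torch.Size([1, 3, 4, 4])
# First element of each channel becomes 1-1, 17-2, 33-3, i.e. 0, 15, 30
print(result[0, :, 0, 0])
```

`b[None, :, None, None]` or `b.view(-1, 1, 1)` would give the same result, since broadcasting aligns shapes from the trailing dimension. If modifying `a` in place is acceptable, `a.sub_(b.view(1, 3, 1, 1))` avoids allocating a new tensor.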