Convolve function channel wise over tensor

Is there any way to convolve a function channel-wise over a tensor?
I have a tensor u of shape torch.Size([8, 16, 32, 32]) = (N, C, H, W)
and trainable parameters:

mu = torch.Size([16])
sigma = torch.Size([16])

Batch-wise, to every channel in the tensor I want to apply the function:

    def growth_func(self, u, mu, sigma):
        return 2 * torch.exp(-(u - mu) ** 2 / (2 * sigma ** 2)) - 1

where every channel gets its own mu and sigma, and then convolve it over the 32x32 image size.
Would I have to expand the dimensions of mu and sigma?

Hi Etienne!

I’m not sure what you mean by “convolve.”

I think you are asking how to apply your function element-wise to your
tensor u, taking into account that each of the 16 channels has its own
value of mu and sigma.

(To me, “convolve” implies that you have a sliding window that mixes
neighboring values in the tensor u together.)

If I understand properly what you want to do, you do not need to
expand() any dimensions of mu and sigma. Instead you only need
to add singleton dimensions (“trivial” dimensions of length one that
don’t require any additional storage) to mu and sigma and let pytorch
broadcast them over the non-channel dimensions of u.

Consider:

>>> import torch
>>> print (torch.__version__)
2.1.2
>>>
>>> u = torch.ones (8, 16, 32, 32)
>>> mu = torch.arange (1.0, 17.0) / 16.0
>>> sigma = torch.arange (1.0, 17.0) / 8.0
>>>
>>> mu.shape
torch.Size([16])
>>> muB = mu[None, :, None, None]                        # add singleton dimensions for non-channel dimensions
>>> muB.shape
torch.Size([1, 16, 1, 1])
>>>
>>> sigmaB = sigma[None, :, None, None]                  # add singleton dimensions for non-channel dimensions
>>>
>>> result = 2 * torch.exp(-(u - muB) ** 2 / (2 * sigmaB ** 2)) - 1   # broadcast along non-channel dimensions
>>>
>>> result.shape
torch.Size([8, 16, 32, 32])
>>> result[0, :, 0, 0]
tensor([-1.0000, -0.9956, -0.8087, -0.3507,  0.0921,  0.4133,  0.6266,  0.7650,
         0.8543,  0.9120,  0.9490,  0.9724,  0.9867,  0.9949,  0.9989,  1.0000])
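(As an aside, the same singleton dimensions can be added in a few equivalent ways; indexing with None, view(), and unsqueeze() all produce the same broadcastable shape:)

```python
import torch

mu = torch.arange(1.0, 17.0) / 16.0                  # shape: [16]

a = mu[None, :, None, None]                          # index with None
b = mu.view(1, -1, 1, 1)                             # view to an explicit shape
c = mu.unsqueeze(0).unsqueeze(-1).unsqueeze(-1)      # unsqueeze one dim at a time

print(a.shape)                                       # torch.Size([1, 16, 1, 1])
print(torch.equal(a, b) and torch.equal(b, c))       # True
```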

Best.

K. Frank

Thank you!
This is indeed what I meant. I thought I might need to do a convolution with kernel size 1, but this is much simpler.

Should I change the title of my question to help others with the same issue, or leave it as is?