Inputs:
- I = Tensor of dim (N, C, X) (Input)
- W = Tensor of dim (N, X, Y) (Weight)
Output:
- O = Tensor of dim (N, C, Y) (Output)
I want to compute:
I = I.view(N, C, X, 1)
W = W.view(N, 1, X, Y)
PROD = I*W
O = PROD.sum(dim=2)
return O
without incurring N * C * X * Y memory overhead.
Basically, I want to compute a weighted sum over a feature map where the weights are shared across the channel dimension, without paying extra memory for every channel.
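Written out, the target is O[n, c, y] = sum over x of I[n, c, x] * W[n, x, y], which has the form of a batched matrix product. Here is a minimal sketch that checks this on small tensors; the sizes are made up for illustration, and the names `O_ref`, `O_bmm`, and `O_einsum` are not part of the original code. The outputs of `torch.bmm` and `torch.einsum` have shape (N, C, Y), so no (N, C, X, Y) tensor is written explicitly, though actual peak memory has not been measured here.

```python
import torch

# Hypothetical sizes chosen only for this illustration.
N, C, X, Y = 2, 3, 4, 5
I = torch.randn(N, C, X)
W = torch.randn(N, X, Y)

# Reference: broadcast-and-sum, which materializes an (N, C, X, Y) tensor.
O_ref = (I.view(N, C, X, 1) * W.view(N, 1, X, Y)).sum(dim=2)

# The same contraction over X as a batched matrix multiply:
# O[n, c, y] = sum_x I[n, c, x] * W[n, x, y]
O_bmm = torch.bmm(I, W)

# Equivalent spelling with explicit index notation.
O_einsum = torch.einsum("ncx,nxy->ncy", I, W)

# All three results should agree up to floating-point error.
print(torch.allclose(O_ref, O_bmm, atol=1e-5))
print(torch.allclose(O_ref, O_einsum, atol=1e-5))
```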
Maybe I could use
import torch
from itertools import product
O = torch.zeros(N, C, Y)
for n, x, y in product(range(N), range(X), range(Y)):
    O[n, :, y] += I[n, :, x] * W[n, x, y]
return O
but that would be slower (a Python loop instead of broadcasting), and I’m not sure how much memory autograd would use to save tensors for the backward pass.
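A possible middle ground between the full broadcast and the triple loop is to loop over X only and keep broadcasting over N, C, and Y. The sketch below assumes a hypothetical helper name, `weighted_sum_over_x`; each step creates a temporary with N * C * Y elements rather than N * C * X * Y. How much autograd retains for the backward pass is not measured here.

```python
import torch

def weighted_sum_over_x(I, W):
    # Hypothetical helper: loops over X only, broadcasting over N, C and Y.
    N, C, X = I.shape
    Y = W.shape[2]
    O = torch.zeros(N, C, Y, dtype=I.dtype, device=I.device)
    for x in range(X):
        # (N, C, 1) * (N, 1, Y) -> (N, C, Y): one temporary per step.
        # Out-of-place add, so no tensor tracked by autograd is modified in place.
        O = O + I[:, :, x].unsqueeze(2) * W[:, x, :].unsqueeze(1)
    return O
```

It can be called as `O = weighted_sum_over_x(I, W)` with I of shape (N, C, X) and W of shape (N, X, Y). The Python loop now runs X times instead of N * X * Y times.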