Memory-efficient weighted sum of channels


  1. I = Tensor of dim (N, C, X) (Input)
  2. W = Tensor of dim (N, X, Y) (Weight)
  3. O = Tensor of dim (N, C, Y) (Output)

I want to compute:

I = I.view(N, C, X, 1)
W = W.view(N, 1, X, Y)
PROD = I * W  # broadcasts to (N, C, X, Y)
O = PROD.sum(dim=2)
return O

without incurring N * C * X * Y memory overhead.

Basically I want to calculate the weighted sum of a feature map wherein the weights are the same along the channel dimension, without incurring memory overhead per channel.

Maybe I could use

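from itertools import product

import torch

O = torch.zeros(N, C, Y)
for n, x, y in product(range(N), range(X), range(Y)):
    # accumulate the contribution of position x to output column y, for all channels at once
    O[n, :, y] += I[n, :, x] * W[n, x, y]
return O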

but that would be slower (no broadcasting) and I’m not sure how much memory overhead would be incurred by saving variables for the backward pass.

You can change the loop you’ve written to iterate over X only. I’m not 100% sure that this will work well with autograd, but I think it will (i.e. copies of O won’t be stored because the addition is in-place, and the overhead of slicing I and W is constant). There is also the option of putting this loop in an autograd.Function or using checkpoint… Unfortunately, I don’t know of a better way to do this; maybe addbmm.
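For reference, here is a minimal sketch of the loop over X only, using the shapes from the question. The function names weighted_sum_loop and weighted_sum_bmm are made up for illustration, and I haven't measured memory use or autograd behaviour for either one. The second function is there because the sum over x of I[n, c, x] * W[n, x, y] is a batched matrix product, so torch.bmm should give the same result.

```python
import torch


def weighted_sum_loop(I, W):
    """Accumulate over X only (hypothetical helper, not from the thread).

    I: (N, C, X), W: (N, X, Y) -> O: (N, C, Y)
    """
    N, C, X = I.shape
    Y = W.shape[2]
    O = torch.zeros(N, C, Y, dtype=I.dtype, device=I.device)
    for x in range(X):
        # I[:, :, x] is (N, C) and W[:, x, :] is (N, Y); the unsqueezed
        # operands broadcast to (N, C, Y), so no (N, C, X, Y) tensor is built.
        O += I[:, :, x].unsqueeze(2) * W[:, x, :].unsqueeze(1)
    return O


def weighted_sum_bmm(I, W):
    """Same contraction written as a batched matrix multiply."""
    # (N, C, X) @ (N, X, Y) -> (N, C, Y)
    return torch.bmm(I, W)
```

Both should match the broadcast-and-sum version from the question on small inputs. Note that addbmm also sums over the batch dimension N, so the batched variants (bmm, or baddbmm if you want to add into an existing tensor) look like the closer fit here.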