Inputs:

- I = Tensor of dim (N, C, X) (Input)
- W = Tensor of dim (N, X, Y) (Weight)

Output:

- O = Tensor of dim (N, C, Y) (Output)

I want to compute:

```
I = I.view(N, C, X, 1)   # (N, C, X, 1)
W = W.view(N, 1, X, Y)   # (N, 1, X, Y)
PROD = I * W             # broadcasts to (N, C, X, Y)
O = PROD.sum(dim=2)      # (N, C, Y)
return O
```

without incurring **N * C * X * Y memory overhead**.

Basically, I want to compute a weighted sum over a feature map where the weights are shared across the channel dimension, without incurring per-channel memory overhead.
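If I'm reading the shapes right, this contraction is exactly a batched matrix product over the X dimension, so I'd expect `torch.bmm` (or the equivalent `einsum`) to give the same result without ever materializing the (N, C, X, Y) intermediate. A sketch of the equivalence I'm after (shapes are just small example values):

```
import torch

N, C, X, Y = 2, 3, 4, 5
I = torch.randn(N, C, X)
W = torch.randn(N, X, Y)

# Reference: broadcast-and-sum, which materializes (N, C, X, Y)
ref = (I.view(N, C, X, 1) * W.view(N, 1, X, Y)).sum(dim=2)

# Batched matmul: contracts over X directly
O_bmm = torch.bmm(I, W)                      # (N, C, X) @ (N, X, Y) -> (N, C, Y)
O_ein = torch.einsum('ncx,nxy->ncy', I, W)   # same contraction, spelled out

print(torch.allclose(ref, O_bmm, atol=1e-6))
print(torch.allclose(ref, O_ein, atol=1e-6))
```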

Maybe I could use

```
from itertools import product

O = torch.zeros(N, C, Y)
for n, x, y in product(range(N), range(X), range(Y)):
    O[n, :, y] += I[n, :, x] * W[n, x, y]
return O
```

but that would be slower (no broadcasting), and I'm not sure how much memory overhead would be incurred by the variables saved for the backward pass.
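As a sanity check on the backward pass, here's a sketch comparing the gradients of the broadcast-and-sum formulation against a batched-matmul formulation; it doesn't measure memory, it just confirms the two paths backpropagate identically (small example shapes, assumed here):

```
import torch

N, C, X, Y = 2, 3, 4, 5
I0 = torch.randn(N, C, X)
W0 = torch.randn(N, X, Y)

# Path 1: broadcast-and-sum
I1 = I0.clone().requires_grad_(True)
W1 = W0.clone().requires_grad_(True)
(I1.view(N, C, X, 1) * W1.view(N, 1, X, Y)).sum(dim=2).sum().backward()

# Path 2: batched matrix multiply
I2 = I0.clone().requires_grad_(True)
W2 = W0.clone().requires_grad_(True)
torch.bmm(I2, W2).sum().backward()

print(torch.allclose(I1.grad, I2.grad, atol=1e-5))
print(torch.allclose(W1.grad, W2.grad, atol=1e-5))
```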