Hi, I have two tensors, x with shape [B, B, C, 1] and y with shape [1, B, C, N]. I want to compute torch.mean(x*y, dim=1), which outputs a tensor with shape [B, C, N].
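Written out element-wise, with the singleton dimensions dropped and k standing for the index reduced by dim=1 (k runs over the second B), the mean is a sum of products over k. For each fixed channel c it is therefore an ordinary [B, B] by [B, N] matrix product, scaled by 1/B:

$$
\mathrm{out}[b,c,n] \;=\; \frac{1}{B}\sum_{k=1}^{B} x[b,k,c,0]\; y[0,k,c,n],
\qquad
\mathrm{out}[:,c,:] \;=\; \frac{1}{B}\, X_c Y_c,
\quad X_c[b,k] = x[b,k,c,0],\;\; Y_c[k,n] = y[0,k,c,n].
$$

Nothing in this sum needs the full [B, B, C, N] array at once. The largest result that has to exist is the [B, C, N] output itself.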
But the broadcast product x*y has shape [B, B, C, N], which is too large (12.76 GB) because B and C are large. I tried torch.einsum, hoping to avoid materializing the broadcast product, but torch.einsum seems to build the intermediate x*y implicitly as well, and I get the same OOM error.
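One possible workaround is to treat C as a batch dimension and let a batched matrix multiply do the sum over k, so the [B, B, C, N] product is never formed. Below is a minimal sketch using torch.bmm. The function name mean_of_product is hypothetical, and the sketch assumes x and y have exactly the shapes given above (x: [B, B, C, 1], y: [1, B, C, N]).

```python
import torch


def mean_of_product(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: equivalent to torch.mean(x * y, dim=1) for
    x: [B, K, C, 1] and y: [1, K, C, N] (here K == B), computed as a
    batched matmul over C instead of a [B, K, C, N] broadcast product."""
    K = x.shape[1]
    # [B, K, C, 1] -> [B, K, C] -> [C, B, K]: one [B, K] matrix per channel c
    xc = x.squeeze(-1).permute(2, 0, 1)
    # [1, K, C, N] -> [K, C, N] -> [C, K, N]: one [K, N] matrix per channel c
    yc = y.squeeze(0).permute(1, 0, 2)
    # Batched matmul sums over k: [C, B, K] @ [C, K, N] -> [C, B, N]
    out = torch.bmm(xc, yc)
    # Back to [B, C, N]; dividing by K turns the sum over k into a mean
    return out.permute(1, 0, 2) / K


# Small sanity check against the original broadcasting expression
B, C, N = 4, 3, 5
x = torch.randn(B, B, C, 1, dtype=torch.float64)
y = torch.randn(1, B, C, N, dtype=torch.float64)
ref = torch.mean(x * y, dim=1)  # fine at this size, OOMs at the real size
out = mean_of_product(x, y)
print(out.shape)                 # torch.Size([4, 3, 5])
print(torch.allclose(out, ref))  # True
```

Peak memory should be roughly the size of x, y and the [B, C, N] output (bmm may make a contiguous copy of a permuted input), rather than 12.76 GB for the product. The same contraction can also be written as torch.einsum('bkc,kcn->bcn', x.squeeze(-1), y.squeeze(0)) / B. With the singleton dimensions squeezed away, einsum should be able to reduce this to a batched matmul too, but that is an implementation detail worth measuring on your PyTorch version.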
Does anybody have a suggestion or work-around?
Thanks a lot.