Workaround for large broadcast causing CUDA out of memory

Hi, I have two tensors: x with shape [B, B, C, 1] and y with shape [1, B, C, N].

I want to compute torch.mean(x*y, dim=1), which outputs a tensor with shape [B, C, N].

But the broadcast result x*y, with shape [B, B, C, N], is too large (12.76GB) because B and C are large. I tried torch.einsum to avoid the broadcast product, but it seems torch.einsum generates the intermediate x*y implicitly and hits the same OOM error.

Does anybody have a suggestion or work-around?

Thanks a lot.

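Perhaps a batched matmul: (C, B1, B2) @ (C, B2, N) = (C, B1, N). The mean over dim=1 is a sum over the shared B2 axis divided by B, so with C as the batch dimension the [B, B, C, N] product never has to be materialized.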
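
Here is a minimal sketch of that batched matmul, assuming x and y have exactly the shapes from the question. The function name mean_of_broadcast_product is made up for illustration; torch.bmm, squeeze, and permute are the standard PyTorch calls.

```python
import torch


def mean_of_broadcast_product(x, y):
    """Compute torch.mean(x * y, dim=1) without building the [B, B, C, N] product.

    x: [B, B, C, 1], y: [1, B, C, N]  ->  result: [B, C, N]
    (hypothetical helper name, shapes as assumed in the question)
    """
    B = x.shape[1]  # size of the axis being averaged over
    # Drop the singleton dims and move C to the front as the batch dimension.
    xb = x.squeeze(-1).permute(2, 0, 1)  # [C, B1, B2]
    yb = y.squeeze(0).permute(1, 0, 2)   # [C, B2, N]
    # Batched matmul contracts over B2, i.e. the sum over dim=1 of x * y.
    out = torch.bmm(xb, yb)              # [C, B1, N]
    # Back to the [B, C, N] layout, and turn the sum into a mean.
    return out.permute(1, 0, 2) / B


# Small sanity check against the naive broadcast version.
x = torch.randn(4, 4, 3, 1, dtype=torch.float64)
y = torch.randn(1, 4, 3, 5, dtype=torch.float64)
assert torch.allclose(mean_of_broadcast_product(x, y), torch.mean(x * y, dim=1))
```

The largest tensors here are the inputs and the [B, C, N] output, plus any contiguous copies of the permuted inputs that bmm may need internally, so peak memory should stay far below the size of the broadcast product.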