Inplace Broadcast Reduction Operation

I have an operation like this:

import torch

x = torch.randint(16, size=(128, 1, 3, 32, 32), dtype=torch.float32)
h = torch.randint(16, size=(1, 6, 3, 32, 32), dtype=torch.float32)
z = (x * h).sum(2)  # output shape (128, 6, 32, 32)

Is there a way to avoid the memory allocation of the intermediate result x*h? Some mix of broadcasting, in-place operations, and reduction in a single call?

Wondering if anyone has a solution for this.

Also, to clarify the question: the issue is that the broadcasted operation would require more memory than is available. So the goal is not to avoid memory allocation altogether, but to perform the reduction without running out of memory.

You could replace the broadcasting with a loop iterating over e.g. x, which would trade compute for memory; a sketch follows below.
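
A minimal sketch of that loop approach, assuming the same shapes as in the question and iterating over the first dimension of x, could look like this:

import torch

x = torch.randint(16, size=(128, 1, 3, 32, 32), dtype=torch.float32)
h = torch.randint(16, size=(1, 6, 3, 32, 32), dtype=torch.float32)

# Preallocate the output so only one slice of the broadcasted product
# exists at a time, instead of the full (128, 6, 3, 32, 32) intermediate.
z = torch.empty(128, 6, 32, 32)
for i in range(x.size(0)):
    # x[i] has shape (1, 3, 32, 32); x[i] * h[0] broadcasts to (6, 3, 32, 32),
    # and summing over dim 1 (the original dim 2) gives (6, 32, 32).
    z[i] = (x[i] * h[0]).sum(dim=1)

# Should match the direct version (if that one fits in memory):
# assert torch.allclose(z, (x * h).sum(2))

Each iteration only materializes a (6, 3, 32, 32) intermediate rather than the full broadcasted tensor; you could also loop in chunks of x instead of single slices to balance memory use against the Python loop overhead.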