Broadcast a PyTorch tensor across channels based on another tensor

I have two arrays, x and y, with the same shape. x represents data and y represents which class each datapoint in x belongs to. I want to create a new tensor where the data in x are partitioned into channels based on their classes in y.

I can accomplish this with a one-hot encoding. However, for large tensors (especially ones with many classes), the one-hot tensor — an int64 tensor of shape B × C × N — quickly exhausts GPU memory.

Is there a more memory-efficient way to do this broadcasting?

import torch

B, C, N = 2, 10, 1000

x = torch.randn(B, 1, N)
y = torch.randint(low=0, high=C, size=(B, 1, N))

one_hot = torch.nn.functional.one_hot(y, C)  # B 1 N C, int64
one_hot = one_hot.squeeze(1).permute(0, 2, 1)  # B C N (squeeze(1) is safe even when B == 1)

z = x * one_hot  # B C N
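
One possible memory-saving sketch (assuming the goal is exactly the `z` above) is to allocate the output directly and use `torch.Tensor.scatter_` to write each value of `x` into the channel indicated by `y`, so the large int64 one-hot intermediate is never materialized:

```python
import torch

B, C, N = 2, 10, 1000

x = torch.randn(B, 1, N)
y = torch.randint(low=0, high=C, size=(B, 1, N))

# Allocate the B x C x N output once, then scatter along the channel
# dimension: z[b, y[b, 0, n], n] = x[b, 0, n]. All other entries stay 0,
# matching the one-hot multiply, but no int64 one-hot tensor is created.
z = torch.zeros(B, C, N, dtype=x.dtype, device=x.device)
z.scatter_(1, y, x)
```

`scatter_` works here because the index tensor `y` and the source `x` share the shape `(B, 1, N)`, which satisfies its shape requirements along every dimension except the scatter dimension.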
