I have two arrays, `x` and `y`, with the same shape. `x` represents data and `y` represents which class each datapoint in `x` belongs to. I want to create a new tensor where the data in `x` are partitioned into channels based on their classes in `y`.
I can accomplish this with a one-hot encoding, as below. However, for large tensors (especially with a large number of classes `C`), PyTorch's one-hot encoding materializes an intermediate of shape `(B, 1, N, C)` and quickly exhausts GPU memory.

Is there a more memory-efficient way to do this broadcasting?
```python
import torch

B, C, N = 2, 10, 1000
x = torch.randn(B, 1, N)
y = torch.randint(low=0, high=C, size=(B, 1, N))

one_hot = torch.nn.functional.one_hot(y, C)    # (B, 1, N, C)
one_hot = one_hot.squeeze(1).permute(0, 2, 1)  # (B, C, N)
z = x * one_hot                                # (B, C, N)
```
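For reference, one route I have been considering is `Tensor.scatter_`, which writes `x` directly into a preallocated `(B, C, N)` output without ever materializing the `(B, 1, N, C)` one-hot tensor. A minimal sketch, assuming the desired result is `z[b, c, n] = x[b, 0, n]` when `y[b, 0, n] == c` and `0` otherwise:

```python
import torch

B, C, N = 2, 10, 1000
x = torch.randn(B, 1, N)
y = torch.randint(low=0, high=C, size=(B, 1, N))

# Preallocate the output and scatter each datapoint into the channel
# selected by y: z[b, y[b, 0, n], n] = x[b, 0, n]; all other entries
# remain zero.
z = torch.zeros(B, C, N)
z.scatter_(1, y, x)
```

This only allocates the `(B, C, N)` output, but I am not sure whether it is the idiomatic way to do this, or whether there is something even cheaper.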