Updates:
I came up with something like this:
import torch
from torch.nn.utils.rnn import pad_sequence

idx = torch.tensor([[0, 0, 1, 1], [0, 0, 0, 1], [2, 2, 3, 3], [2, 2, 2, 2]])
feat = torch.tensor([[11, 12, 21, 22], [13, 14, 15, 23], [31, 32, 41, 42], [33, 34, 35, 36]])

# turn the index into a one-hot mask: onehot[c] marks the positions of cluster c
onehot = torch.zeros(4, 4, 4)
onehot.scatter_(0, idx.view(1, 4, 4), 1)

# extract the features of each cluster, in cluster order
onehot = onehot.view(4, -1)
feat = feat.view(1, -1)
sub = torch.masked_select(feat, onehot == 1)

# split into per-cluster chunks and pad them to equal length
cluster_size = torch.sum(onehot, dim=1).long().tolist()
clusters = torch.split(sub, cluster_size)
padded_feat = pad_sequence(clusters, batch_first=True)
print(padded_feat)
It seems to work, although this version can't handle a batch dimension.
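That said, batching might be doable with an offset trick: shift every batch element's cluster ids into their own disjoint range, flatten the whole batch, and do a single stable sort so each cluster's features end up contiguous. A sketch with made-up example data, assuming every cluster id appears in every batch element (otherwise torch.split would produce empty chunks):

```python
import torch
from torch.nn.utils.rnn import pad_sequence

# made-up batch: two index maps that both use cluster ids 0 and 1
idx = torch.tensor([[[0, 0, 1, 1],
                     [0, 1, 1, 1]],
                    [[0, 0, 1, 1],
                     [0, 0, 0, 1]]])            # shape (B, H, W)
feat = torch.arange(16).view(2, 2, 4)           # same shape as idx

B = idx.shape[0]
n_clusters = int(idx.max()) + 1
# shift ids so batch b occupies its own range [b * n_clusters, (b + 1) * n_clusters)
offset = torch.arange(B).view(B, 1, 1) * n_clusters
global_idx = (idx + offset).view(-1)
flat_feat = feat.reshape(-1)

# stable sort groups equal ids together while keeping their original order
order = torch.argsort(global_idx, stable=True)
counts = torch.bincount(global_idx, minlength=B * n_clusters)
clusters = torch.split(flat_feat[order], counts.tolist())
out = pad_sequence(clusters, batch_first=True).view(B, n_clusters, -1)
print(out)
```

The one-hot tensor never gets materialized, so memory stays at O(B·H·W) instead of O(n_clusters·B·H·W).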
Original post:
I don’t know if “scatter concatenation” is the correct name, but I would like to do the following operation:
Suppose I have an index matrix Idx and a feature matrix F:
Idx = [[0, 0, 1, 1], [0, 0, 0, 1], [2, 2, 3, 3], [2, 2, 2, 2]]
F = [[a1,a2,b1,b2], [a3,a4,a5,b3], [c1,c2,d1,d2], [c3,c4,c5,c6]]
I would like to take the features of each cluster defined in the index and put them in a row with padding wherever necessary, so the result looks like:
out = [[a1,a2,a3,a4,a5,0], [b1,b2,0,0,0,0], [c1,c2,c3,c4,c5,c6], [d1,d2,0,0,0,0]]
It’s like a scatter operation, but without the reduction. The order within each row of the output doesn’t matter, i.e. [a1,a2,a3,a4,a5,0] and [a2,a1,a4,a3,a5,0] are both okay.
I suppose I can use a for loop and reconstruct each row of the output:

out = torch.zeros(4, 6)
for i in range(4):
    data = F[Idx == i]
    out[i, :data.shape[0]] = data

but I am wondering if there’s a more efficient way to do this. Any suggestion would be appreciated.
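One loop-free alternative (a sketch, not necessarily the fastest option): flatten both tensors, stably sort by the index so each cluster’s features become contiguous, then split by cluster size and pad. Using the numeric stand-ins for a1, a2, … from the update above:

```python
import torch
from torch.nn.utils.rnn import pad_sequence

idx = torch.tensor([[0, 0, 1, 1], [0, 0, 0, 1], [2, 2, 3, 3], [2, 2, 2, 2]])
feat = torch.tensor([[11, 12, 21, 22], [13, 14, 15, 23], [31, 32, 41, 42], [33, 34, 35, 36]])

flat_idx = idx.view(-1)
flat_feat = feat.view(-1)
# stable sort keeps the original row-major order within each cluster
order = torch.argsort(flat_idx, stable=True)
counts = torch.bincount(flat_idx)          # number of entries per cluster
clusters = torch.split(flat_feat[order], counts.tolist())
out = pad_sequence(clusters, batch_first=True)
print(out)
# → [[11, 12, 13, 14, 15,  0],
#    [21, 22, 23,  0,  0,  0],
#    [31, 32, 33, 34, 35, 36],
#    [41, 42,  0,  0,  0,  0]]
```

Sorting is O(N log N) over the N index entries, but it avoids building the (n_clusters × H × W) one-hot tensor entirely.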