It seems you want to index a tensor of shape [N, C, seq_len] with another tensor of shape [N, K], so I’m unsure where a mask is coming from.
If I understand the use case correctly, this should work:
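import torch

N, C, seq_len, K = 2, 3, 4, 5
x = torch.randn(N, C, seq_len)
idx = torch.randint(0, seq_len, (N, K))
# The three index tensors have shapes [N, 1, 1], [1, C, 1], and [N, 1, K];
# they broadcast to [N, C, K], so result[n, c, k] == x[n, c, idx[n, k]]
result = x[torch.arange(N)[:, None, None], torch.arange(C)[None, :, None], idx.unsqueeze(1)]
print(result.shape)
> torch.Size([2, 3, 5])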
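
If the goal really is to pick, for every sample, the same K positions along seq_len in each channel, torch.gather might be an easier-to-read alternative. This is only a sketch under that assumption: the names `ref`, `gathered`, and `same` are made up for illustration, and it just checks that both approaches agree on random data.

```python
import torch

N, C, seq_len, K = 2, 3, 4, 5
x = torch.randn(N, C, seq_len)
idx = torch.randint(0, seq_len, (N, K))

# Reference: the advanced-indexing version from above
ref = x[torch.arange(N)[:, None, None], torch.arange(C)[None, :, None], idx.unsqueeze(1)]

# gather needs an index with the same number of dims as x,
# so expand idx from [N, K] to [N, C, K] (same positions for every channel)
gathered = torch.gather(x, 2, idx.unsqueeze(1).expand(N, C, K))

# Both should select x[n, c, idx[n, k]]
same = torch.equal(ref, gathered)
print(gathered.shape, same)
```

expand does not copy idx, so the extra memory cost should be small; whether gather or advanced indexing is faster likely depends on the shapes and device.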