Hi there!
I have a 3D tensor indices of shape [16, 12, 12] and a 4D tensor values of shape [16, 12, 128, 64]. (The sizes may vary, but the number of dimensions is fixed.)
I want to keep the rows of values along dimension 2 whose positions appear in indices, and zero out all the others.
I can write it with for loops like the following, but it's slow and I need to take advantage of vectorized torch tensor operations instead.
import torch

q = torch.empty_like(values)
for i in range(values.size(0)):
    for j in range(values.size(1)):
        for k in range(values.size(2)):
            # keep row k only if k is listed in indices[i][j]
            if k in indices[i][j]:
                q[i][j][k] = values[i][j][k]
            else:
                q[i][j][k] = torch.zeros(values.size(3))
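
For reference, here is a rough sketch of one possible loop-free version, in case it helps frame the question. It assumes indices is an integer (int64) tensor whose entries are valid positions along dimension 2 of values. The function name keep_indexed_rows is just a made-up helper for illustration. The idea is to build a boolean mask with scatter_ and then zero out everything the mask does not cover.

```python
import torch


def keep_indexed_rows(values: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
    """Keep values[i, j, k] where k appears in indices[i, j]; zero out the rest.

    Hypothetical helper. Assumes indices has dtype int64 and every entry
    lies in the range [0, values.size(2)).
    """
    # Boolean mask over the first three dimensions, initially all False.
    mask = torch.zeros(values.shape[:3], dtype=torch.bool, device=values.device)
    # Mark every position listed in indices as True along dimension 2.
    mask.scatter_(2, indices, True)
    # Broadcast the mask over the last dimension and zero the unselected rows.
    return values.masked_fill(~mask.unsqueeze(-1), 0.0)


# Example with the shapes from the post (random data, for illustration only).
values = torch.randn(16, 12, 128, 64)
indices = torch.randint(0, 128, (16, 12, 12))
q = keep_indexed_rows(values, indices)
print(q.shape)
```

Duplicate entries in indices are harmless here, since they just set the same mask position to True more than once.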