I have a 3D tensor indices of shape [16, 12, 12] and a 4D tensor values of shape [16, 12, 128, 64] (the sizes may vary, but the number of dimensions is fixed).
I want to keep the entries of values along dimension 2 whose positions are listed in indices, and zero out the others.
I can write it with for loops like the following, but it's slow, and I need to take advantage of torch tensor operations.
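import torch

q = torch.zeros_like(values)  # output, same shape as values
for i in range(values.size(0)):
    for j in range(values.size(1)):
        for k in range(values.size(2)):
            # keep row k only if k is listed in indices[i][j]
            if k in indices[i][j]:
                q[i][j][k] = values[i][j][k]
            else:
                q[i][j][k] = torch.zeros(values.size(3))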
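
One possible vectorized version of the loop above is sketched here. It assumes that indices is an integer (int64) tensor whose entries are valid positions along dimension 2 of values, which is how the loop uses them. The function name keep_indexed_rows is made up for this example. The idea is to build a 0/1 mask of shape [B, H, N] with scatter_ and broadcast it over the last dimension.

```python
import torch

def keep_indexed_rows(values: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
    """Zero out every values[i, j, k] whose k is not listed in indices[i, j].

    Hypothetical helper; assumes indices holds int64 positions along dim 2.
    """
    # mask[i, j, k] will be 1 where k appears in indices[i, j], else 0
    mask = torch.zeros(values.shape[:3], dtype=values.dtype, device=values.device)
    mask.scatter_(2, indices, 1.0)
    # broadcast the mask over the last dimension of values
    return values * mask.unsqueeze(-1)

# small example with the shapes from the question
values = torch.randn(16, 12, 128, 64)
indices = torch.randint(0, 128, (16, 12, 12))
q = keep_indexed_rows(values, indices)
print(q.shape)
```

Duplicate entries in indices are harmless here, since they write the same 1 to the same position. Note that multiplying by the mask keeps NaN or inf values in the zeroed rows as NaN; if that matters, torch.where with a boolean mask may be a better fit.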