Replace tensor values at partial indexes

I want to set the values at specific indices to 0, but my code is very, very slow…

import torch

mask = [1, 2, 3, 4]
x = torch.zeros(100, 64, 32, 32)

# One Python-level assignment per (sample, channel) pair: 4 * 100 = 400 in total
for m in mask:
    for i in range(100):
        x[i][m] = 0

How do I optimize my code?
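A quick sketch of where the time goes, assuming the goal is to zero whole channels for every sample. The helper names `zero_channels_loop` and `zero_channels_per_channel` are hypothetical, and random data is used so the zeroing is visible. Dropping only the inner loop already replaces 100 small assignments per channel with a single slice assignment, because `x[:, m]` selects channel `m` for every sample at once.

```python
import torch

def zero_channels_loop(x, mask):
    # Original approach (hypothetical helper name): one assignment per (sample, channel)
    for m in mask:
        for i in range(x.shape[0]):
            x[i][m] = 0  # x[i] is a view, so this writes into x
    return x

def zero_channels_per_channel(x, mask):
    # Inner loop removed: x[:, m] covers channel m of every sample in one call
    for m in mask:
        x[:, m] = 0
    return x

x = torch.rand(100, 64, 32, 32)
mask = [1, 2, 3, 4]
a = zero_channels_loop(x.clone(), mask)
b = zero_channels_per_channel(x.clone(), mask)
print(torch.equal(a, b))  # True
```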

This code should work:

import torch

x = torch.zeros(100, 64, 32, 32)
mask = torch.tensor([1, 2, 3, 4])
x[:, mask] = 1  # sets channels 1-4 of every sample in a single call

Note that I changed the assigned value to 1; otherwise you would just be writing 0 into a tensor that already contains only zeros. ;)
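For comparison, here is a sketch of an in-place alternative using `Tensor.index_fill_(dim, index, value)`, which fills every position listed in `index` along `dim`. It should give the same result as the advanced-indexing assignment in the answer above; `mask` must hold integer (int64) indices for this call.

```python
import torch

x = torch.zeros(100, 64, 32, 32)
mask = torch.tensor([1, 2, 3, 4])  # int64 indices, as index_fill_ expects

# Advanced indexing, as in the answer above
a = x.clone()
a[:, mask] = 1

# In-place alternative: fill the indices in `mask` along dim 1 (channels)
b = x.clone()
b.index_fill_(1, mask, 1)

print(torch.equal(a, b))  # True
print(a.sum().item())     # 409600.0, i.e. 100 * 4 * 32 * 32
```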


Good JOB!! Many Thanks!!!

How can I do this along the last dimension of the tensor?
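One possible approach, sketched here and not taken from the thread: the Ellipsis `...` stands in for all leading dimensions, so `x[..., mask]` applies the index list to the last dimension. The same result should follow from `index_fill_` with the last dimension passed explicitly as `x.dim() - 1`.

```python
import torch

x = torch.zeros(100, 64, 32, 32)
mask = torch.tensor([1, 2, 3, 4])

# `...` expands to all leading dimensions, so `mask` indexes the last one
x[..., mask] = 1

# Equivalent in-place form with the last dimension given explicitly
y = torch.zeros(100, 64, 32, 32)
y.index_fill_(x.dim() - 1, mask, 1)

print(torch.equal(x, y))  # True
```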