How to efficiently remove in-place operations?

I have PyTorch code that performs batched gradient descent, but it contains an in-place operation that prevents me from using autograd.

More specifically, I have determined that this line is the problem:

    y[idx] = y[idx] - alpha[idx, None, None] * dy
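
The post doesn't include the error message, but one common way an index assignment like this breaks autograd can be sketched as follows, using made-up shapes and a made-up loss (y0, the (4, 2, 3) shape and the squared-sum loss are assumptions for illustration only). Autograd saves y for the backward pass of y ** 2, the in-place assignment then bumps y's version counter, and backward() raises a RuntimeError saying a variable needed for gradient computation was modified by an inplace operation.

```python
import torch

# Hypothetical data: a batch of 4 iterates, each of shape 2x3.
y0 = torch.randn(4, 2, 3, requires_grad=True)
alpha = torch.rand(4)
idx = torch.tensor([0, 2])
dy = torch.randn(2, 2, 3)   # stand-in for the real gradient of the selected rows

y = y0 * 1.0                # non-leaf tensor tracked by autograd
loss = (y ** 2).sum()       # pow saves y (at version 0) for the backward pass

# The problematic in-place update: it increments y's version counter.
y[idx] = y[idx] - alpha[idx, None, None] * dy

try:
    loss.backward()
    failed = False
except RuntimeError as err:  # "... modified by an inplace operation ..."
    failed = True
    print(err)
```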

I have managed to replace the operation with the following:

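    import torch

    ysel = y[idx] - alpha[idx, None, None] * dy  # updated rows only
    yall = []
    count = 0  # position of the next row to take from ysel
    for i in range(y.shape[0]):
        if i in idx:  # assumes idx is sorted in ascending order
            yall.append(ysel[count])
            count = count + 1
        else:
            yall.append(y[i])  # keep the original row
    y = torch.stack(yall, dim=0)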

But this is horribly inefficient, so I'm wondering whether anyone has a better way to extract elements from a tensor and insert new values back into it, without looping over every index to decide whether to update it or keep the original value.

Any ideas for this particular problem, or general advice on how to replace in-place operations efficiently, would be much appreciated.
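
One possible vectorized replacement for the loop, sketched under assumed shapes (y of shape (B, M, N), a 1-D LongTensor idx of unique row indices, alpha of shape (B,), dy of shape (len(idx), M, N)), is the out-of-place Tensor.index_copy, which returns a new tensor with the selected rows replaced instead of writing into y. This is a swapped-in technique, not the approach suggested later in the thread, and batched_step is a hypothetical helper name.

```python
import torch

def batched_step(y, idx, alpha, dy):
    """Out-of-place equivalent of y[idx] = y[idx] - alpha[idx, None, None] * dy.

    Assumed (hypothetical) shapes: y is (B, M, N), idx is a 1-D LongTensor of
    unique row indices, alpha is (B,), dy is (len(idx), M, N).
    """
    ysel = y[idx] - alpha[idx, None, None] * dy
    # index_copy (no trailing underscore) returns a NEW tensor in which the rows
    # listed in idx are replaced by ysel; y itself is not modified.
    return y.index_copy(0, idx, ysel)

# Usage with made-up data: 4 iterates of shape 2x3, rows 0 and 2 updated.
y = torch.randn(4, 2, 3, requires_grad=True)
idx = torch.tensor([0, 2])
alpha = torch.rand(4)
dy = torch.randn(2, 2, 3)

y_new = batched_step(y, idx, alpha, dy)
y_new.sum().backward()  # no tensor saved by autograd was overwritten

# Reference: the original in-place line applied to a detached copy.
ref = y.detach().clone()
ref[idx] = ref[idx] - alpha[idx, None, None] * dy
print(torch.allclose(y_new.detach(), ref))
```

A boolean mask over the batch dimension combined with torch.where is another out-of-place option, at the cost of computing the update for every row rather than only the selected ones.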

You could check out functorch.experimental.functionalize, which was added in PyTorch 1.12.0 and might help with your use case.
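
A minimal sketch of how functionalize could wrap the original in-place update: it takes a function that mutates tensors and returns a version in which those mutations are rewritten as out-of-place operations. The step function, shapes and data here are made up for illustration. In PyTorch 2.x the same transform is exposed as torch.func.functionalize, so the import tries that first and falls back to the functorch.experimental path named above. The sketch only checks that both versions compute the same values; whether this removes the autograd error in the original training loop is not verified here.

```python
import torch

try:
    from torch.func import functionalize  # PyTorch 2.x location
except ImportError:
    from functorch.experimental import functionalize  # PyTorch 1.12 + functorch

def step(y, idx, alpha, dy):
    # Hypothetical step written with the original in-place index assignment.
    y = y.clone()  # avoid mutating the caller's tensor
    y[idx] = y[idx] - alpha[idx, None, None] * dy
    return y

# Made-up data: 4 iterates of shape 2x3, rows 0 and 2 updated.
y = torch.randn(4, 2, 3)
idx = torch.tensor([0, 2])
alpha = torch.rand(4)
dy = torch.randn(2, 2, 3)

step_functional = functionalize(step)  # in-place ops replaced by functional ones
out_ref = step(y, idx, alpha, dy)
out_fn = step_functional(y, idx, alpha, dy)
print(torch.allclose(out_ref, out_fn))
```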


I haven't tested this yet, but from what I've read about it, it looks perfect. Thank you very much!