Backpropagation through slicing with a list

I have a similar problem to this issue on GitHub; however, no solution is proposed there.

The proposed example code:

import torch

slize = [1, 2, 3, 4]
x = torch.randn(10, requires_grad=True)
y = x[slize]
# breaks the second time backward is called
y.sum().backward()  # raises RuntimeError

In my case I have a large tensor and I want to iterate gradient descent steps on mini-batches.


old_log_probs = ...  # size [8000, 1]
for epoch in range(num_epochs):
    for batch_index in rollout.shuffle_index(batch_size=8):
        loss = ppo_loss(
            new_log_probs,               # size [8, 1]
            old_log_probs[batch_index],  # size [8, 1]
            advantages[batch_index],     # size [8, 1]
        )

This happens because getting a subtensor with a list returns a copy, not a view.
Is it possible to get a subtensor at random indices as a view?
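For what it's worth, you can check the copy-vs-view behaviour yourself by comparing storage pointers. A minimal sketch (note that autograd still backpropagates through the copy, since advanced indexing has a defined gradient, so the copy by itself is not what breaks backward):

```python
import torch

x = torch.randn(10, requires_grad=True)

view = x[1:5]           # basic slicing returns a view: it shares x's storage
copy = x[[1, 2, 3, 4]]  # list (advanced) indexing returns a copy: new storage

print(view.data_ptr() == x.data_ptr() + x.element_size())  # True: offset into x
print(copy.data_ptr() == x.data_ptr())                     # False: separate memory
```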


There is a solution: set retain_graph=True in the first call to backward if you want to backprop through it again later :slight_smile:
This behavior is expected.
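A minimal illustration of the flag (not your code, just a sketch of the mechanics):

```python
import torch

x = torch.randn(10, requires_grad=True)
y = (x * 2).sum()

y.backward(retain_graph=True)  # keeps the graph's intermediate buffers alive
y.backward()                   # second call now succeeds; gradients accumulate
```

Each call adds dy/dx = 2 into x.grad, so after both calls every entry of x.grad is 4.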

Hi, I tried the solution you proposed, but it doesn't work.

Do you have any other ideas? Here is my error log when retain_graph=True:

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [128, 2]], which is output 0 of TBackward, is at version 2; expected version 1 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!

When calling loss.backward() without retain_graph:

RuntimeError: Trying to backward through the graph a second time, but the saved intermediate results have already been freed. Specify retain_graph=True when calling backward the first time.

It does work :smiley: It is just that the code has another, unrelated problem.
Running the updated sample works fine:

import torch
slize = [1, 2, 3, 4]
x = torch.randn(10, requires_grad=True)
y = x[slize]
y.sum().backward() # runs fine

As the error mentions, you modify in-place a Tensor whose value is required for the backward computation. You can follow the instructions in the error message to find the faulty op and replace the in-place op there with an out-of-place one.
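A minimal reproduction of that failure mode (the tensor names are illustrative, not taken from your code): an intermediate tensor saved for backward is mutated in place, which bumps its version counter; swapping the in-place op for its out-of-place counterpart fixes it.

```python
import torch

w = torch.randn(3, requires_grad=True)
y = w * 2
z = y ** 2   # autograd saves y to compute dz/dy = 2 * y
y.add_(1)    # in-place: bumps y's version, invalidating the saved value

try:
    z.sum().backward()
except RuntimeError as e:
    print(e)  # "... modified by an inplace operation ..."

# Out-of-place fix: y + 1 creates a new tensor and leaves the saved one intact
w = torch.randn(3, requires_grad=True)
y = w * 2
z = y ** 2
y = y + 1
z.sum().backward()  # works
```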