Debug and Trace the Backward Process

Recently I implemented a new function with a custom backward pass. However, I ran into the error RuntimeError: No grad accumulator for a saved leaf!

Therefore, I would like to know:

  • whether there are ways to insert code into the backward process and trace it, as I can with the forward process

  • why this error occurs

The core part of my code is as follows:

class MixedOp(nn.Module):
    MODE = None
    def __init__(self, C, stride):
        super(MixedOp, self).__init__()
        # candidate_ops is a list of pytorch modules, such as pooling and convolution
        self.candidate_ops = nn.ModuleList()
        # insert operations to the candidate_ops
        ......
    
    def forward(self, x):
        def run_function(candidate_ops, active_id):
            def forward(_x):
                return candidate_ops[active_id](_x)
            return forward
        
        # backward here computes the gradients w.r.t. the binary gates
        def backward_function(candidate_ops, active_id, binary_gates):
            def backward(_x, _output, grad_output):
                binary_grads = torch.zeros_like(binary_gates.data)
                with torch.no_grad():
                    for k in range(len(candidate_ops)):
                        if k != active_id:
                            out_k = candidate_ops[k](_x.data)
                        else:
                            out_k = _output.data
                        grad_k = torch.sum(out_k * grad_output)
                        binary_grads[k] = grad_k
                return binary_grads
            return backward
        
        output = ArchGradientFunction.apply(
            x, self.alpha_gate, run_function(self.candidate_ops, self.active_index[0]),
            backward_function(self.candidate_ops, self.active_index[0], self.alpha_gate)
        )
        return output


class ArchGradientFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, binary_gates, run_func, backward_func):
        ctx.run_func = run_func
        ctx.backward_func = backward_func

        detached_x = detach_variable(x)
        with torch.enable_grad():
            output = run_func(detached_x)
        ctx.save_for_backward(detached_x, output)
        return output.data
    
    @staticmethod
    def backward(ctx, grad_output):
        detached_x, output = ctx.saved_tensors

        grad_x = torch.autograd.grad(output, detached_x, grad_output, only_inputs=True)
        # compute gradients w.r.t. binary_gates
        binary_grads = ctx.backward_func(detached_x.data, output.data, grad_output.data)

        return grad_x[0], binary_grads, None, None

For reference, the full error message is:

Traceback (most recent call last):
  File "train_search.py", line 267, in <module>
    main()
  File "train_search.py", line 160, in main
    optimizer_weight, optimizer_arch, lr, train_arch=False, args=args)
  File "train_search.py", line 218, in train
    loss.backward()
  File "/home/ma-user/anaconda3/envs/Pytorch-1.0.0/lib/python3.6/site-packages/torch/tensor.py", line 102, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/home/ma-user/anaconda3/envs/Pytorch-1.0.0/lib/python3.6/site-packages/torch/autograd/__init__.py", line 90, in backward
    allow_unreachable=True)  # allow_unreachable flag
  File "/home/ma-user/anaconda3/envs/Pytorch-1.0.0/lib/python3.6/site-packages/torch/autograd/function.py", line 76, in apply
    return self._forward_cls.backward(self, *args)
  File "/home/ma-user/work/ziqipang/darts-binary/models/mix_op.py", line 178, in backward
    detached_x, output = ctx.saved_tensors
RuntimeError: No grad accumulator for a saved leaf!

Hi,

I am not sure what you are trying to do here, but mixing .data (which should not be used anymore) with detaching things inside a custom Function is quite dangerous.

  • As a general rule, you should never use .data. If you want a new Tensor with no history, use .detach().
  • save_for_backward should only be called with inputs or outputs of the Function.
  • History is not tracked through save_for_backward/saved_tensors, so you cannot do this and expect the grad call in your backward to work.
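To illustrate the last point, here is a minimal sketch (not the original code) of one pattern that keeps the gradient computation in backward working: save only the Function's input, then recompute the forward inside backward with grad enabled. This is roughly the reentrant idea behind torch.utils.checkpoint. RecomputeFunction is a hypothetical name, and an nn.Linear stands in for a candidate op.

```python
import torch
import torch.nn as nn


class RecomputeFunction(torch.autograd.Function):
    """Hypothetical sketch: save only the input, rebuild the graph in backward."""

    @staticmethod
    def forward(ctx, x, op):
        ctx.op = op
        # x is an input of the Function, so saving it is allowed.
        ctx.save_for_backward(x)
        with torch.no_grad():
            return op(x)

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        # Recompute the forward with grad enabled on a fresh leaf, so there
        # is a real graph for autograd to differentiate.
        with torch.enable_grad():
            x_ = x.detach().requires_grad_(True)
            out = ctx.op(x_)
        # Reentrant backward: accumulates into the op's parameters' .grad
        # and into x_.grad.
        torch.autograd.backward(out, grad_output)
        return x_.grad, None  # no gradient for the non-tensor `op` argument
```

Usage is RecomputeFunction.apply(x, op). The gradients for x and for the op's parameters should match those of calling op(x) directly.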

What are you trying to do here? Only to set some gradients for binary_gates, even though they are not used in run_func?

Hi! Thank you very much for your help! My setup is a little complex, so I'll do my best to summarize it.

Specifically, I am doing Neural Architecture Search (NAS). The MixedOp here represents an edge between two nodes, and it contains several candidate operations. In every iteration, I select one candidate operation for the forward pass. The selection is specified by the variable self.alpha_gate, and the active operation has index active_index[0]. During the backward pass, I want to compute the gradients w.r.t. the binary gates, i.e. self.alpha_gate.

I took the formula for computing these gradients from an open-source paper, so I think it is correct. I am only confused about the RuntimeError.

Can you confirm that I understood what you want properly:

  • output is computed as candidate_ops[active_idx](x).
    • Where does active_idx come from?
    • The gradient should be computed as the gradient of the selected op (which is differentiable, so autograd works here).
  • active_idx is just used for indexing but you want to get gradients for it
    • Since this is not differentiable, you want to use a custom Function for that.
    • The gradient formula is just sum(output_of_that_op * grad_output) for each score associated with an op (whether or not it was the selected one).

Thanks! I think your first point is correct, but the second is slightly different from what I am trying to do.

I use binary gates self.alpha_gate, which is a tensor composed of 0 and 1, to specify the usage of candidate operations. During the forward pass, the final result can be treated as the output of each candidate operation multiplied by its entry in self.alpha_gate, summed over all candidates. Finally, I want to compute the gradient w.r.t. self.alpha_gate.
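Written out, with gates g_k in {0, 1}, candidate outputs o_k(x), and upstream gradient dL/dy (grad_output), the chain rule gives the gate gradient quoted earlier as sum(output_of_that_op * grad_output), and shows why x receives gradient only from the active op when the gates are one-hot:

```latex
y = \sum_{k} g_k \, o_k(x), \qquad
\frac{\partial L}{\partial g_k} = \sum_{i} \frac{\partial L}{\partial y_i}\, o_k(x)_i, \qquad
\frac{\partial L}{\partial x} = \sum_{k} g_k \left(\frac{\partial o_k}{\partial x}\right)^{\!\top} \frac{\partial L}{\partial y}
```

Because the gate gradient needs o_k(x) for every k, each inactive candidate still has to be evaluated at some point.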

However, during the computation, I want to save the time of forward computation of all the candidate operations, as there is only one of them contributing to the final output. Therefore, I derive active_idx from the position of the 1 in self.alpha_gate and compute only the output of this active operation. During the backward pass, I still want to compute the gradient w.r.t. self.alpha_gate.

I think that is roughly what I try to do.
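A minimal sketch of that design, assuming the gates are a 1-D 0/1 tensor with one entry per candidate and that all candidates produce outputs of the same shape. GateGradient is a hypothetical name, not the paper's code. The active op runs outside the Function with normal autograd, and the Function only passes its output through while producing gate gradients with sum(o_k * grad_output). Every saved tensor is a Function input.

```python
import torch
import torch.nn as nn


class GateGradient(torch.autograd.Function):
    """Hypothetical sketch: pass the active output through and compute
    gate gradients as sum(o_k * grad_output) for every candidate k."""

    @staticmethod
    def forward(ctx, active_out, gates, x, ops, active_id):
        ctx.ops, ctx.active_id = ops, active_id
        # Only inputs of the Function are saved.
        ctx.save_for_backward(active_out, gates, x)
        return active_out.clone()

    @staticmethod
    def backward(ctx, grad_output):
        active_out, gates, x = ctx.saved_tensors
        gate_grads = torch.zeros_like(gates)
        with torch.no_grad():
            for k, op in enumerate(ctx.ops):
                # Inactive candidates are evaluated only here, in backward.
                out_k = active_out if k == ctx.active_id else op(x)
                gate_grads[k] = (out_k * grad_output).sum()
        # grad_output flows into active_out, so x and the active op's weights
        # get their gradients through the ordinary autograd graph.
        return grad_output, gate_grads, None, None, None
```

Inside a module, this would be used as active_out = ops[active_id](x) followed by GateGradient.apply(active_out, gates, x, ops, active_id). Note that this does not avoid running the inactive candidates; it only moves that work into the backward pass.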

I use binary gates self.alpha_gate, which is a tensor composed of 0 and 1, to specify the usage of candidate operations.

Do you store these results as attributes on self in the nn.Module during the forward? You read the active index from self.active_index[0], which is not computed here, and self.alpha_gate is a different Tensor from what I can see.
In general, you want your nn.Modules to be as stateless as possible, so you might want to avoid saving forward values as attributes on self.

I want to save the time of forward computation of all the candidate operations, as there is only one of them contributing to the final output.

Why does it save time, given that you do these computations in the backward anyway? You could do them in the forward and just save the results, no?
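Following that suggestion, a minimal sketch assuming every candidate op returns a tensor of the same shape and that the gates are passed to forward as a 1-D tensor instead of being stored on self. GatedMixedOp is a hypothetical name. Computing all candidates once in forward lets plain autograd produce the gate gradients, with no custom Function at all.

```python
import torch
import torch.nn as nn


class GatedMixedOp(nn.Module):
    """Hypothetical stateless variant: gates are a forward argument."""

    def __init__(self, candidate_ops):
        super().__init__()
        self.candidate_ops = nn.ModuleList(candidate_ops)

    def forward(self, x, gates):
        # Run every candidate once and stack: shape (K, *output_shape).
        outs = torch.stack([op(x) for op in self.candidate_ops])
        # Reshape gates to (K, 1, ..., 1) so they broadcast over each output,
        # then form y = sum_k g_k * o_k(x).
        g = gates.view(-1, *([1] * (outs.dim() - 1)))
        return (g * outs).sum(dim=0)
```

With one-hot gates the output equals the active candidate's output, and gates.grad[k] comes out as sum(o_k * grad_output), the same formula discussed above.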

Alright! I think I somewhat understand the problem here.

I will post an update when I make progress.

Thank you very much!
