Custom gradients only for inputs that require grad

I have to implement a custom backward pass because part of the computation is done in NumPy. Since the forward pass already loops over the batch, I compute the gradient inside that same loop rather than iterating again in the backward pass. However, the gradient calculation involves some complex indexing that takes a lot of time, so I don't want to compute gradients for inputs that don't require grad. To that end, func returns the gradient in grad when backprop is True and returns None in grad when backprop is False, skipping the gradient calculation entirely. Likewise, I only call ctx.save_for_backward when backprop is True. My question is: is this the right way to implement custom gradients only for inputs that require gradient computation?

import numpy as np
import torch
from torch.autograd import Function
from torch.autograd.function import once_differentiable


class CustomGrad(Function):
    @staticmethod
    def forward(ctx, x, func):
        backprop = x.requires_grad
        result_collec = []
        grad_collec = []
        for b, data in enumerate(x):                # iterate over batch
            result, grad = func(data, backprop)     # grad is None when backprop is False
            result_collec.append(result)
            if grad is not None:
                grad_collec.append(grad)
        if backprop:
            ctx.save_for_backward(torch.from_numpy(np.stack(grad_collec)).to(torch.float32))
        return torch.from_numpy(np.stack(result_collec)).to(torch.float32)

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_output):
        local_grad, = ctx.saved_tensors  # saved_tensors is a tuple
        grad_input = ...  # calculate downstream gradient using local_grad and grad_output
        return grad_input, None
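
For context, here is a minimal sketch of how this class might be invoked, using a made-up square_func that returns a NumPy result plus an elementwise local gradient (everything in this snippet is illustrative, not part of the original code):

def square_func(data, backprop):
    # data is one sample of the batch as a torch tensor; move it to NumPy
    arr = data.detach().cpu().numpy()
    result = arr ** 2
    # only pay for the gradient computation when it is actually needed
    grad = 2 * arr if backprop else None
    return result, grad

x = torch.randn(4, 3, requires_grad=True)
y = CustomGrad.apply(x, square_func)   # custom Functions are called via .apply
# once backward is fully implemented, y.sum().backward() would populate x.grad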

ctx has an attribute ctx.needs_input_grad, a tuple of booleans representing whether each input needs gradient. E.g., backward will have ctx.needs_input_grad[0] = True if the first input to forward needs gradient computed w.r.t. the output.

If you need to know during backward whether a particular input had requires_grad set during forward, you can use ctx.needs_input_grad.

Thanks for the reply! Since I only call ctx.save_for_backward in the forward pass when backprop=True, would a backward like this be fine?

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_output):
        grad_input = None
        if ctx.needs_input_grad[0]:
            grad_local, = ctx.saved_tensors  # saved_tensors is a tuple
            grad_input = ...  # calculate downstream gradient using grad_local and grad_output
        return grad_input, None
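
As an illustration of the remaining step (the actual math depends on what func computes), if the saved local gradient happened to be elementwise and the same shape as x, the missing line could look like this:

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_output):
        grad_input = None
        if ctx.needs_input_grad[0]:
            grad_local, = ctx.saved_tensors
            # chain rule for an elementwise op: dL/dx = dL/dy * dy/dx
            grad_input = grad_output * grad_local
        return grad_input, None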

Yes, that should be fine
