I have to implement a custom backpropagation because I need to do some calculations with NumPy. The problem is that I will be doing the gradient calculation inside the for loop of the forward pass, since I don't want to go through another for loop in the backward pass. However, the gradient calculation involves some complex indexing that takes a lot of time, so I don't want to calculate gradients for inputs that don't require grad. Thus, `func` returns the gradient in `grad` when `backprop` is `True`, and returns `None` in `grad` when `backprop` is `False` by skipping the gradient calculation. Likewise, I save `grad` for the backward pass with `ctx.save_for_backward` only when `backprop` is `True`. My question is: is this the right way to implement custom gradients only for inputs that require gradient calculation?
```python
import numpy as np
import torch
from torch.autograd import Function
from torch.autograd.function import once_differentiable


class CustomGrad(Function):
    @staticmethod
    def forward(ctx, x, func):
        backprop = x.requires_grad
        result_collec = []
        grad_collec = []
        for data in x:  # iterate over the batch
            result, grad = func(data, backprop)
            result_collec.append(result)
            if grad is not None:
                grad_collec.append(grad)
        if backprop:
            ctx.save_for_backward(torch.from_numpy(np.stack(grad_collec)).to(torch.float32))
        return torch.from_numpy(np.stack(result_collec)).to(torch.float32)

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_output):
        local_grad, = ctx.saved_tensors
        grad_input = ...  # calculate downstream gradient using local_grad and grad_output
        return grad_input, None  # no gradient for the func argument
```
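
To make the contract of `func` concrete, here is a minimal toy stand-in for it. This is only a sketch: `square_with_grad` is not my real calculation, the NumPy conversion inside it is just how I imagine `func` would work, and the commented backward line only applies to this element-wise case.

```python
import numpy as np
import torch


def square_with_grad(data, backprop):
    """Toy stand-in for the real NumPy routine: returns (result, grad).

    grad is the local derivative d(result)/d(data), computed only when
    backprop is True; otherwise None is returned so the expensive part
    is skipped.
    """
    data = data.detach().cpu().numpy()
    result = data ** 2
    grad = 2.0 * data if backprop else None
    return result, grad


x = torch.randn(4, 3, requires_grad=True)
y = CustomGrad.apply(x, square_with_grad)

# For this element-wise toy case the backward placeholder would simply be
#     grad_input = local_grad * grad_output
# but my real local gradient has a more complicated structure.
```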