After defining a custom layer with customized gradient computation, do I need to have the gradients with requires_grad=True? Or is it enough just to have the requires_grad as true for the input and the output of that layer?
Does it cause any problem if I unnecessarily set requires_grad=True for some variables?
Thanks in advance!
It might create autograd state for Tensors that don't need it, but no, it won't cause any problem.
If you define a custom autograd.Function, you can use it as any other function: the inputs should have requires_grad=True during the forward if you want to be able to backprop. The grad_output given to your custom backward will have requires_grad=True if you asked for create_graph=True when calling backward and the Tensors require gradients.
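To make this concrete, here is a minimal sketch of a custom autograd.Function (the `Square` class and its tensors are illustrative, not from the thread): the input has requires_grad=True, the forward saves what the backward needs, and calling backward on the output populates the input's .grad.

```python
import torch
from torch.autograd import Function

class Square(Function):
    """Minimal custom Function: y = x**2 with a hand-written backward."""

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)   # stash x for use in backward
        return x * x

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        return grad_output * 2 * x  # chain rule: dL/dx = dL/dy * 2x

x = torch.tensor([3.0], requires_grad=True)  # input must require grad to backprop
y = Square.apply(x)                          # use .apply, not the constructor
y.backward()
print(x.grad)  # tensor([6.])
```

Note that nothing inside forward or backward needed requires_grad=True; only the input did.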
Thanks a lot! Where exactly should I use create_graph=True? Do I really need it? Here is my code, which defines a custom loss class with all the computations inside:
class CustomLoss(Function):  # Inherit from Function
    @staticmethod
    def forward(ctx, X_input, imgD):
        Loss, Gradients = LossAndGradients(X_input, imgD)
        ctx.save_for_backward(Gradients)  # save so backward can read them
        return Loss

    @staticmethod
    def backward(ctx, grad_output):
        Gradients, = ctx.saved_tensors
        grad_imgD = None          # no gradient needed w.r.t. imgD
        grad_input = Gradients    # precomputed gradient w.r.t. X_input
        return grad_input, grad_imgD
In the above function I have set requires_grad=True for Loss and X_input, and requires_grad=False for Gradients. And I use it as below in the main code:
criterion = CustomLoss.apply
I am not really sure if it is back-propagating the gradients in a correct way during training.
You should use create_graph=True when you want higher-order derivatives (like a Hessian or a gradient penalty); otherwise you don't need it.
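As an illustration of that (the function and values below are made up for the example), create_graph=True keeps the backward pass itself differentiable, so you can take a derivative of a derivative:

```python
import torch

x = torch.tensor(2.0, requires_grad=True)
y = x ** 3  # dy/dx = 3x^2, d2y/dx2 = 6x

# First derivative: create_graph=True records the backward computation
# so that g itself can be differentiated again.
(g,) = torch.autograd.grad(y, x, create_graph=True)
print(g)  # 3 * 2**2 = 12

# Second derivative: differentiate the gradient itself.
(h,) = torch.autograd.grad(g, x)
print(h)  # 6 * 2 = 12
```

Without create_graph=True on the first call, the second torch.autograd.grad call would fail because g would not be part of any graph.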
In the custom Function, the forward always runs in no_grad mode, so you don't need to check the requires_grad field in there; the history will properly be attached to the Tensors you return on exit. The backward runs with grad mode enabled only if you set create_graph=True. If you don't, it will also not run in grad mode, and you can ignore the requires_grad fields in there.
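You can check this yourself with a small probe Function (the `Probe` class below is just for demonstration): it records torch.is_grad_enabled() inside forward and backward, and with a plain backward() call both come out False.

```python
import torch
from torch.autograd import Function

class Probe(Function):
    """Records whether grad mode is enabled inside forward/backward."""
    modes = {}

    @staticmethod
    def forward(ctx, x):
        Probe.modes["forward"] = torch.is_grad_enabled()
        return x.clone()

    @staticmethod
    def backward(ctx, grad_output):
        Probe.modes["backward"] = torch.is_grad_enabled()
        return grad_output

x = torch.ones(1, requires_grad=True)
Probe.apply(x).backward()  # no create_graph, so both run without grad mode
print(Probe.modes)  # {'forward': False, 'backward': False}
```

Calling backward(create_graph=True) instead would flip the "backward" entry to True, since the backward computation must then be recorded for higher-order derivatives.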