```python
# Transfer gradient information so that backward() on `target` takes the
# same path as `source`, and only differs because the tensor values differ.
def TransferGradient(source, target):
    target.grad_fn = source.grad_fn
    target._grad_fn = source._grad_fn
    target.requires_grad = True
    return target

# Fast MSE: skip autograd inside the loss, re-attach the graph afterwards
def MSE(y, labels):
    loss = (y - labels)**2
    loss = TransferGradient(y, loss)
    return loss
```
Hi, I would really appreciate your help.
I would like to transfer the gradient information from one tensor to another so that the backward() pass is identical, differing only because the values in the first tensor are different from the values in the second. I was hoping to do this to save time and memory in areas that can be excluded from the backward pass. Take the cost function above as an example: I skip the gradient calculation inside the cost function, but re-attach it to the loss variable at the end, so that `loss.backward()` still computes the gradients after the MSE() function is executed.
Thank you,
Hi @ptrblck_de (I will message other people as well),
Would you mind helping me with this, please? I would really appreciate it, and it would help me better understand PyTorch.
I do not think this question is too challenging or too complex.
P.S. I realize I need to add a delta variable and transfer the gradients from that variable.
So, before entering no_grad():

```python
delta = T.abs(y - labels)
loss = delta**2
```

… then transfer the gradient of that variable:

```python
loss = TransferGradient(delta, loss)
```

Sorry, I could not edit my original post.
One thing that might be faster (you could try it and see whether it actually is) is to define the tensor that should use another tensor's gradients as a custom layer, where the init function takes a reference to the original tensor. Then you can override the backward method to simply return the gradients of that original tensor.
You can have a look here for reference.
As mentioned above, you can use a custom autograd.Function if you want to specify the backward pass for a given function.
See the doc here on how to do that.
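To illustrate the idea, here is a minimal sketch of a custom autograd.Function with a hand-written backward, applied to the MSE case from the first post (the class name `FastMSE` is hypothetical, not from the docs):

```python
import torch

class FastMSE(torch.autograd.Function):
    """Mean squared error with a hand-written backward pass, so autograd
    does not build a graph for the intermediate ops inside the loss."""

    @staticmethod
    def forward(ctx, y, labels):
        # Stash the inputs that backward will need
        ctx.save_for_backward(y, labels)
        return ((y - labels) ** 2).mean()

    @staticmethod
    def backward(ctx, grad_output):
        y, labels = ctx.saved_tensors
        # d/dy mean((y - labels)^2) = 2 * (y - labels) / N
        grad_y = grad_output * 2.0 * (y - labels) / y.numel()
        return grad_y, None  # no gradient w.r.t. labels

y = torch.randn(8, requires_grad=True)
labels = torch.randn(8)
loss = FastMSE.apply(y, labels)  # note: apply, not a direct call
loss.backward()
```

After `loss.backward()`, `y.grad` should match what autograd would have produced for `((y - labels) ** 2).mean()`.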
Thank you for the help!
I am not nearly as familiar with the backward() function, but I gave it a shot. How would you change this?

```python
def forward(ctx, source, target):
    ctx.srce_ = source
    ctx.tgt_ = target

def backward(ctx, *grad_outputs):
    # I do not think this is right. What would need to be changed here?
    grad_outputs = grad_outputs * ctx.tgt_
    return ctx.srce_.backward(ctx, *grad_outputs)
```
Thank you again.
The doc above explains that in detail. You don't call the Function directly; use its apply method instead.
Also, as mentioned there, if you want to save inputs/outputs for the backward pass, you need to use
`ctx.save_for_backward()` to save them.
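Putting those two points together, the earlier attempt could be restructured like this (a sketch only; it assumes, as in the posts above, that the loss computed outside autograd is `delta ** 2`, so its analytic gradient is `2 * delta`):

```python
import torch

class TransferGradient(torch.autograd.Function):
    """Attach precomputed loss values to the graph, routing gradients
    back into `delta` with a hand-written backward."""

    @staticmethod
    def forward(ctx, delta, loss_values):
        # Save inputs via ctx.save_for_backward, not as plain attributes
        ctx.save_for_backward(delta)
        # view_as gives the output its own grad_fn without copying data
        return loss_values.view_as(loss_values)

    @staticmethod
    def backward(ctx, grad_output):
        (delta,) = ctx.saved_tensors
        # d(delta**2)/d(delta) = 2 * delta; return gradients instead of
        # calling .backward() inside backward
        return grad_output * 2.0 * delta, None

y = torch.randn(4, requires_grad=True)
labels = torch.randn(4)
delta = torch.abs(y - labels)
with torch.no_grad():                      # skip autograd inside the loss
    loss_values = delta ** 2
loss = TransferGradient.apply(delta, loss_values)
loss.sum().backward()
```

Since `d/dy (y - labels)**2 = 2 * (y - labels)`, `y.grad` should come out the same as if the loss had been computed entirely under autograd.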