import torch as T

# Transfer the gradient so that backward() of 'target' takes the same path as 'source'
# and only differs because the values of the tensors are different
def TransferGradient(source, target):
    assert source.shape == target.shape
    with T.no_grad():
        target.grad_fn = source.grad_fn
        target._grad_fn = source._grad_fn
        target.requires_grad = True
    return target

# MSE Fast
def MSE(y, labels):
    with T.no_grad():
        loss = (y - labels)**2
    loss = TransferGradient(y, loss)
    return T.mean(loss)
Hi, I would really appreciate your help.
I would like to transfer the gradient information from one tensor to another so that backward() follows the same path for both and differs only because the two tensors hold different values. The goal is to save time and memory by excluding parts of the computation from the backward pass. Take the cost function above as an example: I skip the gradient tracking inside the cost function, then re-apply the gradient information to the loss variable at the end, so that loss.backward() still calculates the gradients after MSE() has run.
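A minimal sketch of why the attempt above cannot restore the graph, assuming a leaf tensor `y` as a stand-in for a model output (the names here are illustrative only). A tensor created under `no_grad()` has no `grad_fn`, and setting `requires_grad = True` on it only turns it into a new leaf, so `backward()` stops at `loss` and never reaches `y`:

```python
import torch

torch.manual_seed(0)
y = torch.randn(4, requires_grad=True)  # stand-in for a model output
labels = torch.randn(4)

# Computed without tracking: the result is detached from y.
with torch.no_grad():
    loss = (y - labels) ** 2

no_graph = loss.grad_fn is None and not loss.requires_grad

# Allowed because loss is a leaf, but it does not reconnect loss to y.
loss.requires_grad = True
torch.mean(loss).backward()

print(no_graph)        # True: nothing was recorded under no_grad()
print(y.grad is None)  # True: the gradient never reached y
print(loss.grad)       # d(mean)/d(loss) = 1/4 for each of the 4 entries
```

The gradient ends up on `loss` itself, which is why the backward pass has to be supplied explicitly if the forward computation is not tracked.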
One thing that might be faster (try it and see whether it actually is) is to define the tensor that should use another tensor's gradients as a custom layer whose __init__ function takes a reference to the original tensor. You can then override the backward method to simply return the gradients of that original tensor.
As mentioned above, you can use a custom autograd.Function if you want to specify the backward pass for a given function.
See the doc here on how to do that.
I am not nearly as familiar with the backward() function, but I gave it a shot. How would you change this?
from torch.autograd import Function

class TransferGradient(Function):
    @staticmethod
    def forward(ctx, source, target):
        ctx.srce_ = source
        ctx.tgt_ = target
        return target

    @staticmethod
    def backward(ctx, *grad_outputs):
        # I do not think this is right. What would need to be changed here?
        grad_outputs[0] = grad_outputs[0] * ctx.tgt_
        return ctx.srce_.backward(ctx, *grad_outputs)

    @staticmethod
    def __call__(*args):
        return TransferGradient.apply(*args)
The doc above explains that in detail. You don't need to define __call__; use TransferGradient.apply directly.
Also, as mentioned there, if you want to save input or output tensors for the backward pass, you need to use ctx.save_for_backward() rather than assigning them as attributes on ctx.
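Putting that advice together, here is a rough sketch of the MSE cost written as a custom `autograd.Function`. The class name `FastMSE` and the test tensors are hypothetical, and this only illustrates the pattern (inputs saved with `ctx.save_for_backward()`, a hand-written `backward()`, and a call through `.apply`). Whether it is actually faster than letting autograd track the operations would need to be measured.

```python
import torch
from torch.autograd import Function


class FastMSE(Function):  # hypothetical name, for illustration only
    @staticmethod
    def forward(ctx, y, labels):
        # Save the inputs needed to compute the gradient later.
        ctx.save_for_backward(y, labels)
        return ((y - labels) ** 2).mean()

    @staticmethod
    def backward(ctx, grad_output):
        y, labels = ctx.saved_tensors
        # d/dy mean((y - labels)^2) = 2 * (y - labels) / N
        grad_y = grad_output * 2.0 * (y - labels) / y.numel()
        # One return value per forward input; labels gets no gradient here.
        return grad_y, None


y = torch.randn(8, 3, dtype=torch.double, requires_grad=True)
labels = torch.randn(8, 3, dtype=torch.double)

# Call through .apply; no __call__ override is needed.
loss = FastMSE.apply(y, labels)
loss.backward()

# Compare against the gradient autograd derives for the same expression.
reference = torch.autograd.grad(((y - labels) ** 2).mean(), y)[0]
matches = torch.allclose(y.grad, reference)
print(matches)
```

The comparison against `torch.autograd.grad` is only a sanity check that the hand-written gradient agrees with the one autograd would derive.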