# Transfer gradient information between two tensors so that their backward() passes are almost identical

```python
import torch as T


def TransferGradient(source, target):
    # Transfer the gradient so that backward() of `target` takes the same path as
    # `source`, and only differs because the values of the tensors are different
    assert source.shape == target.shape
    return target


# Fast MSE
def MSE(y, labels):
    loss = (y - labels) ** 2
    return T.mean(loss)
```

Hi, I would really appreciate your help.

So I would like to transfer the gradient information from one tensor to another so that the backward() pass is the same for both, and only differs because the values in the first tensor are different from the values in the second. I was hoping to do this to save time and memory in areas that can be excluded from the backward pass. Take the cost function above as an example: I would skip the gradient calculation within the cost function, but re-apply it to the `loss` variable at the end, so that `loss.backward()` would calculate the gradients after the MSE() function is executed.

Thank you,
Andrew

Hi @ptrblck_de (I will message other people as well),

Would you mind helping me with this, please? I would really appreciate your help, and it would help me better understand PyTorch.

I do not think that this question is too challenging or too complex.

Thank you,
Andrew

P.S. I realize I need to add a `delta` variable and transfer the gradients from that variable.

```python
delta = T.abs(y - labels)
loss = delta**2
...
```

...and then transfer the gradient from that variable:

```python
loss = TransferGradient(delta, loss)
```

Sorry, I could not edit my original post.
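A minimal sketch of one common alternative, the "straight-through" detach trick. This is a swapped-in technique, not one proposed in this thread. The result has the forward values of `target`, but autograd only records `source`, so gradients flow into `source` as if the result were `source` itself. Note that `source` receives the incoming gradient unchanged; it does not depend on how `target` was computed. The helper name `transfer_gradient` and the toy tensors are hypothetical.

```python
import torch


def transfer_gradient(source: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: values of `target`, gradient routed into `source`.

    Forward value is source + (target - source) == target, but the
    (target - source) term is detached, so autograd only sees `source`.
    """
    assert source.shape == target.shape
    return source + (target - source).detach()


# Toy example: the squared values are computed without building a graph,
# and their gradient is routed through `delta` instead.
y = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
labels = torch.tensor([0.0, 0.0, 0.0])

delta = y - labels            # recorded by autograd
with torch.no_grad():
    loss_vals = delta ** 2    # not recorded by autograd

loss = transfer_gradient(delta, loss_vals)
loss.sum().backward()
print(loss)    # values of delta ** 2
print(y.grad)  # gradient of delta.sum() w.r.t. y, i.e. all ones
```

Because the gradient is passed straight through, `y.grad` here is that of `delta.sum()`, not of `(delta ** 2).sum()`; whether that is acceptable depends on what the gradient is meant to represent.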

One thing that might be faster (you could try it and see whether it actually is) is to define the tensor whose gradients you want to take from another tensor as a custom layer, where the init function takes a reference to the original tensor. Then you can override the backward method to simply return the gradients of that original tensor.

You can have a look here for reference.
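Following the suggestion above, a minimal sketch using a custom `torch.autograd.Function`. Assumption: instead of storing a reference in an `__init__` (current `Function`s use static `forward`/`backward` methods and are called through `.apply`), the source tensor is passed as a forward input. The backward simply hands the incoming gradient to `source` and gives `target` none. The toy tensors are hypothetical.

```python
import torch
from torch.autograd import Function


class TransferGradient(Function):
    """Sketch: output has the values of `target`; in the backward pass the
    incoming gradient is routed to `source` unchanged."""

    @staticmethod
    def forward(ctx, source, target):
        assert source.shape == target.shape
        # Return a copy so the output is a new tensor, not an alias of an input.
        return target.clone()

    @staticmethod
    def backward(ctx, grad_output):
        # One gradient per forward input: everything goes to `source`,
        # `target` gets no gradient.
        return grad_output, None


x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
source = 3 * x  # d(source)/dx = 3
target = torch.tensor([10.0, 20.0, 30.0])

out = TransferGradient.apply(source, target)
out.sum().backward()
print(out)     # values of target
print(x.grad)  # 3 for every element, via the path through `source`
```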

Hi,

As mentioned above, you can use a custom `autograd.Function` if you want to specify the backward pass for a given function.
See the doc here on how to do that.

Thank you for the help!

I am not nearly as familiar with the backward() function, but I gave it a shot. How would you change this?

```python
from torch.autograd import Function


class TransferGradient(Function):
    @staticmethod
    def forward(ctx, source, target):
        ctx.srce_ = source
        ctx.tgt_ = target
        return target

    @staticmethod
    def backward(ctx, grad_output):
        ...
```
Also, as mentioned there, if you want to save inputs or outputs for the backward pass, you need to use `ctx.save_for_backward()` to save them.
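A sketch of the custom-`Function` approach applied to the MSE example from the question: `forward` saves its inputs with `ctx.save_for_backward()` and `backward` computes the gradient analytically, so autograd records no intermediate ops inside the loss. The class name `MSEFunction` is hypothetical, and whether this is actually faster or lighter on memory than plain autograd for such a small expression is not established here; it would need measuring.

```python
import torch
from torch.autograd import Function


class MSEFunction(Function):
    """Hypothetical sketch: mean squared error with a hand-written backward."""

    @staticmethod
    def forward(ctx, y, labels):
        # Save the inputs (not ad-hoc ctx attributes) for the backward pass.
        ctx.save_for_backward(y, labels)
        return ((y - labels) ** 2).mean()

    @staticmethod
    def backward(ctx, grad_output):
        y, labels = ctx.saved_tensors
        diff = y - labels
        # d/dy mean((y - labels)^2) = 2 * (y - labels) / N
        grad_y = grad_output * 2.0 * diff / diff.numel()
        # The gradient w.r.t. labels is the negative; skip it if not needed.
        grad_labels = -grad_y if ctx.needs_input_grad[1] else None
        return grad_y, grad_labels


y = torch.randn(4, 3, dtype=torch.double, requires_grad=True)
labels = torch.randn(4, 3, dtype=torch.double)

loss = MSEFunction.apply(y, labels)
loss.backward()

# Compare against ordinary autograd through the plain expression.
y_ref = y.detach().clone().requires_grad_(True)
((y_ref - labels) ** 2).mean().backward()
print(torch.allclose(y.grad, y_ref.grad))
```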