Is there a way to link the gradient of one tensor to another?

```python
import torch

fc = torch.nn.Linear(256, 256)
A = torch.nn.Embedding(10, 256)

a = A.weight
for _ in range(5):
    a = fc(a)

b = a.clone().detach()
# [1] is there a way to link the gradient of b to A here?
```
I am hoping to use the gradient of `b` to update `A`. The only approach I can think of right now is `[2]`: compute the gradient of `b` with `backward`, then copy that gradient over to `A`. But is there a way to link the gradient of `b` to `A` *before* calling `backward`, as in `[1]` above?
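For reference, here is a minimal sketch of the workaround `[2]` described above. The loss (`b.sum()`) is just a placeholder for illustration; the key steps are marking the detached `b` as requiring grad, running `backward`, and then manually copying `b.grad` into `A.weight.grad`:

```python
import torch

fc = torch.nn.Linear(256, 256)
A = torch.nn.Embedding(10, 256)

a = A.weight
for _ in range(5):
    a = fc(a)

# Detach b from the graph, but let it accumulate its own gradient.
b = a.clone().detach().requires_grad_(True)

# Placeholder loss, just for illustration.
loss = b.sum()
loss.backward()

# [2]: manually copy b's gradient over to A after backward.
with torch.no_grad():
    A.weight.grad = b.grad.clone()
```

Note that this only transfers `b`'s gradient verbatim; it does not backpropagate through the five `fc` applications, which is presumably why a direct link as in `[1]` would be preferable.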