Copying Graph between Instances

Hi all,

I want to copy the computational graph between nn.Parameter tensors. I think the MWE below illustrates my use case in its simplest form.

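import torch
import torch.nn as nn

class Test(nn.Module):
    def __init__(self):
        super().__init__()  # required before assigning an nn.Parameter
        self.p = nn.Parameter(torch.tensor([1.]))

t = Test()
t1 = Test()

t1.p = ( * )  # placeholder: what goes here?

(t1.p ** 2).backward()
assert t1.p.grad == t.p.grad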

What should we put in ( * ) to get the same gradients for t1.p and t.p, i.e. to have t1.p.grad == t.p.grad?
Thank you very much in advance for your help!


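I think the simplest thing to do would be to not copy the graph at all and instead let both parameters share a single .grad tensor. Set this up before calling backward(): by default (no create_graph=True), autograd accumulates into an existing .grad in place, so the gradient written to t1.p.grad is also visible through t.p.grad:

t1.p.grad = torch.zeros_like(t1.p)  # allocate a zero gradient buffer for t1.p
t.p.grad = t1.p.grad                # t.p.grad now refers to the same buffer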
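A complete, runnable version of the MWE with this answer applied might look like the sketch below. It assumes a recent PyTorch, adds the super().__init__() call that nn.Module needs, and drops the ( * ) line, since t1.p simply stays its own parameter here. It shares the gradient buffer rather than copying any graph.

```python
import torch
import torch.nn as nn


class Test(nn.Module):
    def __init__(self):
        super().__init__()  # must run before registering parameters
        self.p = nn.Parameter(torch.tensor([1.0]))


t = Test()
t1 = Test()

# Give both parameters one shared gradient buffer *before* backward().
# Plain backward() (no create_graph) accumulates into an existing .grad
# in place, so the result shows up through both attributes.
t1.p.grad = torch.zeros_like(t1.p)
t.p.grad = t1.p.grad

# Only t1.p takes part in this graph: d(p^2)/dp = 2p = 2 at p = 1.
(t1.p ** 2).backward()

print(t1.p.grad, t.p.grad)
assert torch.equal(t1.p.grad, t.p.grad)
```

Because d(p²)/dp = 2p, both gradients should come out as 2 at p = 1. One caveat: in PyTorch 2.x, Module.zero_grad() and Optimizer.zero_grad() default to set_to_none=True, which replaces each .grad with None and breaks the sharing, so the two assignments would have to be repeated after every such call.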