`.detach()` will return a detached version of your tensor. You can still use the original (“attached”) tensor for further computations, as long as you don’t reassign it (e.g. via `a = a.detach()`):
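```python
import torch
import torch.nn as nn

modelA = nn.Linear(10, 10)
modelB = nn.Linear(10, 10)
modelC = nn.Linear(10, 10)

x = torch.randn(1, 10)
a = modelA(x)

# modelB gets a detached input, so no gradient flows back into modelA
b = modelB(a.detach())
b.mean().backward()
print(modelA.weight.grad)  # None
print(modelB.weight.grad)  # populated
print(modelC.weight.grad)  # None

# `a` itself is still attached, so backpropagating through modelC reaches modelA
c = modelC(a)
c.mean().backward()
print(modelA.weight.grad)  # populated now
print(modelB.weight.grad)  # unchanged
print(modelC.weight.grad)  # populated
```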
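To illustrate the reassignment caveat, here is a minimal sketch (reusing the same toy `nn.Linear` setup, with `a_detached` as an illustrative name) of what happens if you overwrite `a` with its detached version: the detached tensor has no `grad_fn`, so nothing computed from it can send gradients back into `modelA`.

```python
import torch
import torch.nn as nn

modelA = nn.Linear(10, 10)
modelB = nn.Linear(10, 10)

x = torch.randn(1, 10)
a = modelA(x)

# The detached tensor is cut off from the graph that produced `a`
a_detached = a.detach()
print(a.requires_grad, a_detached.requires_grad)  # True False
print(a_detached.grad_fn)  # None

# Pitfall: reassigning `a` to its detached version drops the only
# reference to the part of the graph that leads back to modelA
a = a.detach()
b = modelB(a)
b.mean().backward()
print(modelA.weight.grad)  # None, no gradient can reach modelA anymore
print(modelB.weight.grad is not None)  # True
```

If you need both the detached and the attached version, keep them under separate names as in the example above.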