How to detach specific components in the loss?

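`.detach()` returns a new tensor that shares the same data but is cut off from the computation graph, so no gradient flows back through it. The original ("attached") tensor is unaffected, so you can still use it for further computations as long as you don't overwrite it (e.g. with `a = a.detach()`):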

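```python
import torch
import torch.nn as nn

modelA = nn.Linear(10, 10)
modelB = nn.Linear(10, 10)
modelC = nn.Linear(10, 10)

x = torch.randn(1, 10)
a = modelA(x)

# modelB only sees a detached copy of `a`, so this backward pass stops there
b = modelB(a.detach())
b.mean().backward()
print(modelA.weight.grad)  # None: no gradient flowed back into modelA
print(modelB.weight.grad)  # populated
print(modelC.weight.grad)  # None: modelC has not been used yet

# `a` itself is still attached to modelA's graph, so it can be reused
c = modelC(a)
c.mean().backward()
print(modelA.weight.grad)  # now populated via the path through modelC
print(modelB.weight.grad)  # unchanged from the first backward pass
print(modelC.weight.grad)  # populated
```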
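Applied to a loss with several terms, the same idea lets you block gradients for just one component. The sketch below is illustrative only: `encoder`, `head`, the MSE task loss and the squared-output auxiliary loss are all hypothetical choices, not part of the original question. The auxiliary term uses `features.detach()`, so it still updates `head` but contributes nothing to `encoder`; the encoder gradient should therefore match the one obtained from the task loss alone.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical two-stage model: an encoder feeding a small head
encoder = nn.Linear(10, 10)
head = nn.Linear(10, 1)

x = torch.randn(4, 10)
target = torch.randn(4, 1)


def encoder_grad(include_aux: bool) -> torch.Tensor:
    """Run one backward pass and return a copy of the encoder's weight gradient."""
    encoder.zero_grad()
    head.zero_grad()
    features = encoder(x)

    # Task loss: gradients reach both head and encoder
    loss = F.mse_loss(head(features), target)

    if include_aux:
        # Hypothetical auxiliary loss on detached features:
        # it updates `head` but its gradient stops before `encoder`
        loss = loss + head(features.detach()).pow(2).mean()

    loss.backward()
    return encoder.weight.grad.clone()


g_with_aux = encoder_grad(include_aux=True)
g_task_only = encoder_grad(include_aux=False)
print(torch.allclose(g_with_aux, g_task_only))
```

Detaching inside the loss term, rather than detaching `features` itself, keeps the attached tensor available for the components that should still train the encoder.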