I have convinced myself (mostly) that this works as intended. I tested training the regular and detached versions with (i) the loss from the first model only and (ii) the loss from the second model only:
```python
import torch
import torch.nn as nn
import torch.optim as optim

x0 = torch.randn(128, 1)
x = x0  # Variable is deprecated; plain tensors work directly with autograd
y = x0  # target equals input, so the model learns the identity map
modelA = nn.Linear(1, 32, bias=True)
modelB = nn.Linear(32, 1, bias=True)
model = nn.Sequential(modelA, modelB)
optimizer = optim.SGD(model.parameters(), lr=1e-1)

def lossA(z, _):
    # Depends only on modelA's output
    return z.mean()

def lossB(_, y_pred):
    # Depends on modelB's output (and on modelA through z, unless z is detached)
    return ((y_pred - y) ** 2).mean()

# Regular training steps:
for n, loss_fn in [("lossA", lossA), ("lossB", lossB)]:
    print(f"\nTraining with {n}:")
    for t in range(3):
        optimizer.zero_grad()
        z = modelA(x)
        y_pred = modelB(z)
        # Sum of all parameters: a cheap fingerprint of each model's weights
        A = sum([l.sum() for l in modelA.parameters()])
        B = sum([l.sum() for l in modelB.parameters()])
        loss = loss_fn(z, y_pred)
        print(f"modelA weights: {A.item():.2f}, modelB weights: {B.item():.2f}")
        loss.backward()
        optimizer.step()

# Detached training steps:
for n, loss_fn in [("lossA", lossA), ("lossB", lossB)]:
    print(f"\nDetached training with {n}:")
    for t in range(3):
        optimizer.zero_grad()
        z = modelA(x)
        y_pred = modelB(z.detach())  # cut the graph between modelA and modelB
        A = sum([l.sum() for l in modelA.parameters()])
        B = sum([l.sum() for l in modelB.parameters()])
        loss = loss_fn(z, y_pred)
        print(f"modelA weights: {A.item():.2f}, modelB weights: {B.item():.2f}")
        loss.backward()
        optimizer.step()
```
This gives:

```text
Training with lossA:
modelA weights: 2.73, modelB weights: 0.39
modelA weights: 2.65, modelB weights: 0.39
modelA weights: 2.56, modelB weights: 0.39

Training with lossB:
modelA weights: 2.48, modelB weights: 0.39
modelA weights: 2.51, modelB weights: 0.52
modelA weights: 2.48, modelB weights: 0.46

Detached training with lossA:
modelA weights: 2.50, modelB weights: 0.46
modelA weights: 2.41, modelB weights: 0.46
modelA weights: 2.33, modelB weights: 0.46

Detached training with lossB:
modelA weights: 2.24, modelB weights: 0.46
modelA weights: 2.24, modelB weights: 0.55
modelA weights: 2.24, modelB weights: 0.33
```
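In the latter “detached” case, training with the loss from the second model (lossB) no longer updates modelA (its weights stay at 2.24), while training with lossA updates only modelA, as in the regular case. Sorry, it's not very minimal, but maybe it's useful?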
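Comparing parameter sums works, but a more direct check is to look at which parameters receive a `.grad` at all after `backward()`. The following is a minimal sketch, assuming a fixed seed and a fresh pair of layers; `grads_after` is a hypothetical helper written for this illustration, and the target is the input itself, as above.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)  # assumption: fixed seed, only for reproducibility

x = torch.randn(128, 1)
modelA = nn.Linear(1, 32)
modelB = nn.Linear(32, 1)

def grads_after(loss_name, detach):
    """One forward/backward pass; report (modelA has grads, modelB has grads)."""
    # Reset grads to None so "no gradient reached this module" is visible
    modelA.zero_grad(set_to_none=True)
    modelB.zero_grad(set_to_none=True)
    z = modelA(x)
    y_pred = modelB(z.detach() if detach else z)
    if loss_name == "lossA":
        loss = z.mean()                      # depends on modelA only
    else:
        loss = ((y_pred - x) ** 2).mean()    # depends on modelB (and modelA if not detached)
    loss.backward()
    a_has_grad = all(p.grad is not None for p in modelA.parameters())
    b_has_grad = all(p.grad is not None for p in modelB.parameters())
    return a_has_grad, b_has_grad

for detach in (False, True):
    for name in ("lossA", "lossB"):
        print(f"detach={detach}, {name}: (A gets grad, B gets grad) = {grads_after(name, detach)}")
```

With the detached forward pass, lossB should leave modelA's gradients as `None`, whereas lossA never reaches modelB in either setup, which matches the weight sums above.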