Splitting a model with detach

I have a model which can be decomposed into two sequential parts, with a loss function that combines both outputs:

z = modelA(x)
y = modelB(z)
loss = function(z, y)

I would like to prevent the loss from modelB from backpropagating through modelA. My first thought was to use detach() on the input to the second model, i.e.:

z = modelA(x)
y = modelB(z.detach())  # <-- detach added here
loss = function(z, y)

This runs and doesn't do anything obviously wrong, but I am having trouble verifying that it works correctly. Can anyone confirm whether this is doing what I want?

For further context, I am using a single optimizer that is handed all the parameters. I appreciate that conceptually I could train the two models separately, but this setup is very convenient for the broader project.
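
For reference, one quick way to see where the graph is cut is to inspect grad_fn on the detached tensor; a minimal sketch with stand-in shapes (not my real model):

import torch
import torch.nn as nn

modelA = nn.Linear(4, 8)  # stand-in layer, just for illustration
x = torch.randn(2, 4)

z = modelA(x)
print(z.grad_fn)                 # <AddmmBackward0 ...> -- still connected to modelA's graph
print(z.detach().grad_fn)        # None -- the graph is cut here
print(z.detach().requires_grad)  # False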

I have convinced myself (mostly) that this works as intended. I tested training the regular and detached versions with (i) the loss from the first model only and (ii) the loss from the second model only:

import torch
import torch.nn as nn
import torch.optim as optim

x0 = torch.randn(128, 1)
x = x0
y = x0  # regression target used by lossB
modelA = nn.Linear(1, 32, bias=True)
modelB = nn.Linear(32, 1, bias=True)
model = nn.Sequential(modelA, modelB)
optimizer = optim.SGD(model.parameters(), lr=1e-1)


def lossA(z, _):
    # depends only on z, the output of modelA
    return z.mean()


def lossB(_, y_pred):
    # depends only on y_pred, the output of modelB
    return ((y_pred - y) ** 2).mean()


# Regular training steps:
for n, loss_fn in [("lossA", lossA), ("lossB", lossB)]:
    print(f"\nTraining with {n}:")
    for t in range(3):
        optimizer.zero_grad()
        z = modelA(x)
        y_pred = modelB(z)
        # Sum of all parameter entries: a cheap fingerprint of each model's weights
        A = sum(p.sum() for p in modelA.parameters())
        B = sum(p.sum() for p in modelB.parameters())
        loss = loss_fn(z, y_pred)
        print(f"modelA weights: {A.item():.2f}, modelB weights: {B.item():.2f}")
        loss.backward()
        optimizer.step()


# Detached training steps:
for n, loss_fn in [("lossA", lossA), ("lossB", lossB)]:
    print(f"\nDetached training with {n}:")
    for t in range(3):
        optimizer.zero_grad()
        z = modelA(x)
        y_pred = modelB(z.detach())
        # Sum of all parameter entries: a cheap fingerprint of each model's weights
        A = sum(p.sum() for p in modelA.parameters())
        B = sum(p.sum() for p in modelB.parameters())
        loss = loss_fn(z, y_pred)
        print(f"modelA weights: {A.item():.2f}, modelB weights: {B.item():.2f}")
        loss.backward()
        optimizer.step()

Gives:

Training with lossA:
modelA weights: 2.73, modelB weights: 0.39
modelA weights: 2.65, modelB weights: 0.39
modelA weights: 2.56, modelB weights: 0.39

Training with lossB:
modelA weights: 2.48, modelB weights: 0.39
modelA weights: 2.51, modelB weights: 0.52
modelA weights: 2.48, modelB weights: 0.46

Detached training with lossA:
modelA weights: 2.50, modelB weights: 0.46
modelA weights: 2.41, modelB weights: 0.46
modelA weights: 2.33, modelB weights: 0.46

Detached training with lossB:
modelA weights: 2.24, modelB weights: 0.46
modelA weights: 2.24, modelB weights: 0.55
modelA weights: 2.24, modelB weights: 0.33

In the latter "detached" case, training with the loss from the second model (B) does not update modelA, and vice versa. Sorry it's not very minimal, but maybe it's useful?

Yes, this should do what you want. You can validate it by running backward() and checking that the .grad attributes on modelA's parameters are never populated (they stay None), while modelB's are.
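
A minimal sketch of that check (stand-in shapes and target, reusing the modelA/modelB structure from above):

import torch
import torch.nn as nn

modelA = nn.Linear(1, 32)
modelB = nn.Linear(32, 1)

x = torch.randn(128, 1)
target = torch.randn(128, 1)

z = modelA(x)
y_pred = modelB(z.detach())  # gradient flow into modelA is cut here
loss = ((y_pred - target) ** 2).mean()
loss.backward()

print([p.grad for p in modelA.parameters()])              # [None, None] -- never received gradients
print([p.grad is not None for p in modelB.parameters()])  # [True, True]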