Splitting a model with detach

I have a model that can be decomposed into two sequential parts, with a combined loss function over both outputs:

z = modelA(x)
y = modelB(z)
loss = function(z, y)

I would like to prevent the loss from modelB from backpropagating through modelA. My first thought was to use detach() on the input to the second model, i.e.:

z = modelA(x)
y = modelB(z.detach())  # <-- detach added here
loss = function(z, y)

This runs and doesn't do anything obviously wrong, but I am having trouble verifying that it is working correctly. Can anyone confirm whether this is doing what I want?
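One quick sanity check (independent of any training) is to confirm that the detached tensor really is cut out of the autograd graph; a minimal sketch, assuming modelA and x as above:

z = modelA(x)
zd = z.detach()
print(z.requires_grad)   # expect True  - z is still part of modelA's graph
print(zd.requires_grad)  # expect False - zd carries no graph
print(zd.grad_fn)        # expect None  - nothing for backward() to follow
# note: zd shares storage with z, so in-place edits to zd would also change z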

For further context, I am using a single optimizer that is handed all the parameters. I appreciate that conceptually I could train the two models consecutively, but keeping a single training loop is very convenient for the broader project.
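For reference, the separate-training alternative I mean would look roughly like the sketch below, with one optimizer per sub-model (optA, optB and the two loss terms are placeholders, not my actual code):

optA = optim.SGD(modelA.parameters(), lr=1e-1)
optB = optim.SGD(modelB.parameters(), lr=1e-1)

# update modelA on its own loss term
optA.zero_grad()
z = modelA(x)
loss_a = z.mean()                      # placeholder for the modelA term
loss_a.backward()
optA.step()

# update modelB on its own loss term, treating modelA's output as fixed data
optB.zero_grad()
y_pred = modelB(z.detach())
loss_b = ((y_pred - y) ** 2).mean()    # placeholder for the modelB term
loss_b.backward()
optB.step()

The detach() version keeps everything in one forward pass and one optimizer, which is why I'd prefer it here.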

I have (mostly) convinced myself that this works as intended. I tested training the regular and detached versions with (i) a loss on the first model's output only and (ii) a loss on the second model's output only:

import torch
import torch.nn as nn
import torch.optim as optim

x0 = torch.randn(128, 1)
x = x0          # input (Variable is deprecated; plain tensors work with autograd)
y = x0          # regression target; requires_grad defaults to False
modelA = nn.Linear(1, 32, bias=True)
modelB = nn.Linear(32, 1, bias=True)
model = nn.Sequential(modelA, modelB)
optimizer = optim.SGD(model.parameters(), lr=1e-1)


def lossA(z, _):
    # depends only on modelA's output z
    return z.mean()


def lossB(_, y_pred):
    # depends only on modelB's output (MSE against the target y)
    return ((y_pred - y) ** 2).mean()


# Regular training steps:
for n, loss_fn in [("lossA", lossA), ("lossB", lossB)]:
    print(f"\nTraining with {n}:")
    for t in range(3):
        optimizer.zero_grad()
        z = modelA(x)
        y_pred = modelB(z)
        # scalar fingerprints of each sub-model's parameters, to see which ones change
        A = sum(p.sum() for p in modelA.parameters())
        B = sum(p.sum() for p in modelB.parameters())
        loss = loss_fn(z, y_pred)
        print(f"modelA weights: {A.item():.2f}, modelB weights: {B.item():.2f}")
        loss.backward()
        optimizer.step()


# Detached training steps:
for n, loss_fn in [("lossA", lossA), ("lossB", lossB)]:
    print(f"\nDetached training with {n}:")
    for t in range(3):
        optimizer.zero_grad()
        z = modelA(x)
        y_pred = modelB(z.detach())
        A = sum(p.sum() for p in modelA.parameters())
        B = sum(p.sum() for p in modelB.parameters())
        loss = loss_fn(z, y_pred)
        print(f"modelA weights: {A.item():.2f}, modelB weights: {B.item():.2f}")
        loss.backward()
        optimizer.step()

Gives:

Training with lossA:
modelA weights: 2.73, modelB weights: 0.39
modelA weights: 2.65, modelB weights: 0.39
modelA weights: 2.56, modelB weights: 0.39

Training with lossB:
modelA weights: 2.48, modelB weights: 0.39
modelA weights: 2.51, modelB weights: 0.52
modelA weights: 2.48, modelB weights: 0.46

Detached training with lossA:
modelA weights: 2.50, modelB weights: 0.46
modelA weights: 2.41, modelB weights: 0.46
modelA weights: 2.33, modelB weights: 0.46

Detached training with lossB:
modelA weights: 2.24, modelB weights: 0.46
modelA weights: 2.24, modelB weights: 0.55
modelA weights: 2.24, modelB weights: 0.33

In the latter "detached" case, training with the loss from the second model (B) does not update modelA, and vice versa. Sorry it's not very minimal, but maybe it's useful?
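A more direct check than watching the parameter sums is to inspect the gradients immediately after backward(); a sketch reusing the setup above:

optimizer.zero_grad()
z = modelA(x)
y_pred = modelB(z.detach())
loss = lossB(z, y_pred)   # loss on the second model's output only
loss.backward()

# with detach, no gradient should reach modelA, while modelB should receive some
print(all(p.grad is None or p.grad.abs().sum() == 0 for p in modelA.parameters()))   # expect True
print(any(p.grad is not None and p.grad.abs().sum() > 0 for p in modelB.parameters()))  # expect True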