During backprop, is it possible to train only a part of the model?

Suppose my model consists of three parts:

h = f1(x)
h' = f2(h)
y = f3(h')

where f1 is a pre-trained model (I want to keep it fixed during training), f2 is not differentiable and has no learnable parameters, and f3 is the only part I want to train.
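For concreteness, here is a minimal sketch of the kind of setup I mean (the shapes and the use of torch.round as the non-differentiable step are just placeholders):

import torch
import torch.nn as nn

f1 = nn.Linear(10, 10)   # stand-in for the pre-trained part I want to keep fixed
f3 = nn.Linear(10, 2)    # the only part I want to train

def f2(h):
    # stand-in for the non-differentiable step with no learnable parameters
    return torch.round(h)

x = torch.randn(4, 10)
h = f1(x)
h2 = f2(h)   # h' above
y = f3(h2)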

When I run loss.backward(), will autograd try to backprop gradients all the way back to x? How can I ensure that backprop stops at h'?

I know how to do this in Torch, but I am new to PyTorch. Thanks in advance!

Yes, it will propagate gradients all the way back to x.
An easy way to stop backprop at h' is to call detach_() on it:

h = f1(x)
h2 = f2(h)      # h' in your notation
h2.detach_()    # in-place: detaches h2 from the graph that created it, so backprop stops here
y = f3(h2)
y.backward()    # y must be a scalar to call backward() without arguments
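For completeness, a minimal sketch of how this could fit into a training step, reusing the f1/f2/f3/x from the sketch in the question above; the loss, target, and optimizer here are just placeholders for whatever you already use:

import torch
import torch.nn as nn

# Optional: mark f1's parameters as frozen so no gradients are requested for them at all
for p in f1.parameters():
    p.requires_grad_(False)

# Give the optimizer only f3's parameters
optimizer = torch.optim.SGD(f3.parameters(), lr=0.1)
criterion = nn.MSELoss()          # placeholder loss
target = torch.randn(4, 2)        # placeholder target matching f3's output shape

h = f1(x)
h2 = f2(h)
h2 = h2.detach()                  # out-of-place alternative to detach_(); cuts the graph here
y = f3(h2)
loss = criterion(y, target)

optimizer.zero_grad()
loss.backward()                   # gradients only reach f3's parameters
optimizer.step()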

Thanks! That’s exactly what I want. 🙂