Detach an intermediate variable without recomputing the whole graph

Hi,

I have a situation that looks like this:

import torch.nn as nn
from torchvision.models import resnet50

model = resnet50()
classifier = nn.Linear(1000, 20)

x = model(input)
out1 = classifier(x)
out2 = classifier(x.detach()) # Computation to avoid

where the classifier is expensive and I would like to avoid recomputing out2 = classifier(x.detach()).

Can we reuse out1 to backward as if x were detached?

Thank you

Hi,

You mean that you want to get the gradients for the parameters of the classifier but not the ones in model?
If so, you can use autograd.grad to specify which Tensors you want the gradients for: grads = autograd.grad(loss, classifier.parameters()).
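
A minimal runnable sketch of this suggestion (the input shape and the dummy target below are assumptions made just so the snippet runs):

import torch
import torch.nn as nn
from torchvision.models import resnet50

model = resnet50()
classifier = nn.Linear(1000, 20)

inp = torch.randn(4, 3, 224, 224)
target = torch.randint(0, 20, (4,))

x = model(inp)
out1 = classifier(x)
loss = nn.functional.cross_entropy(out1, target)

# Only the classifier's parameters are passed as inputs, so the returned
# gradients are w.r.t. those parameters only and the resnet's .grad fields
# are left untouched.
grads = torch.autograd.grad(loss, list(classifier.parameters()))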


Thank you for the quick response. Actually, my example simplified my case too much; let me provide a better one.

I’m not trying to compute the gradient of the loss w.r.t. my classifier parameters only.

x1, x2 = model(x)
out = heavy_module(x1, x2)
loss = Loss1(out)
loss.backward(retain_graph=True)  # retain the graph so the second backward can still reach model through x2

out_d = heavy_module(x1.detach(), x2)  # Expensive computation to avoid
loss_d = Loss2(out_d)
loss_d.backward()

Can we avoid the second heavy_module computation while still having access to both out and out_d for their different losses?

Thank you

You won’t be able to get two Tensors out of a single forward :smiley:
Why can’t you reuse out to compute loss_d?

If I use out to compute loss_d, then x1 won’t be detached and Loss2 will propagate gradients into the parameters that produced x1, which I want to avoid.
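
A tiny sketch of what happens in that case (the stand-in tensors and shapes are made up just to show the gradient flow):

import torch

w_model = torch.randn(3, 3, requires_grad=True)  # pretend these are the parameters that produce x1
x = torch.randn(3)

x1 = w_model @ x                     # stand-in for one output of model(x)
x2 = torch.randn(3, requires_grad=True)

out = (x1 * x2).sum()                # stand-in for heavy_module(x1, x2)
loss_d = out ** 2                    # stand-in for Loss2(out)
loss_d.backward()

# Because x1 was not detached in this graph, Loss2's gradient reaches the
# parameters that produced x1:
print(w_model.grad)                  # non-zero, which is exactly what should be avoided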

Hi,

I’m afraid you would need two completely separate graphs here, so you won’t be able to do this without re-doing the forward through heavy_module.
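
For reference, a runnable sketch of the two-forward pattern from the question above, with small stand-in modules and shapes that are purely illustrative:

import torch
import torch.nn as nn

model = nn.Linear(8, 8)              # stand-in: pretend its output is split into (x1, x2)
lin = nn.Linear(8, 4)

def heavy_module(a, b):              # stand-in for the expensive module
    return lin(torch.cat([a, b], dim=1))

inp = torch.randn(2, 8)
feat = model(inp)
x1, x2 = feat[:, :4], feat[:, 4:]

# First graph: Loss1 reaches model through both x1 and x2.
out = heavy_module(x1, x2)
loss = out.pow(2).mean()
loss.backward(retain_graph=True)     # keep the graph so the second backward can still reach model through x2

# Second graph: the forward through heavy_module is repeated with x1
# detached, so Loss2 reaches model only through x2.
out_d = heavy_module(x1.detach(), x2)
loss_d = out_d.abs().mean()
loss_d.backward()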
