Skip the calculation of gradients for some layers after the forward pass

Hello,

Let’s say I have 6 (transformer) layers (i.e., L0, L1, L2, L3, L4, L5), all with requires_grad = True. I do a forward pass through all the layers. Based on the representations, I may not want to calculate the gradients for some layers (to save some time), say L0, L1, and L2, i.e., I want to stop the gradient calculation at L2.

To do this, I initially set requires_grad = True on all layers. After the forward pass, if I decide not to calculate the gradients for L0, L1, and L2, I set requires_grad = False on those layers. Even after doing this, the program takes the same time to run (I have checked the gradients: they are None, so they are not calculated). Can anyone help me out with this? What am I doing wrong here, or is there another way to do this?

Note: I set requires_grad = False on the word embedding layer.

Code:

for param in student_model.distilbert.embeddings.parameters():
    param.requires_grad = False

for param in student_model.distilbert.transformer.layer[:curr_flag].parameters():
    param.requires_grad = False

Thank you in advance.

Hi Lipril!

If I understand your use case correctly, you want to perform a forward pass through all layers, L0, L1, …, L5, and then, based on the result of the full forward pass (that is, the output of L5), decide whether or not to compute gradients for L0, L1, and L2.

Since your decision depends on the full forward pass, you have to incur the cost of computing L0, …, L2, but you don’t necessarily have to incur the cost of tracking those gradients. If you will usually be computing gradients for L0, …, L2, you should do something like:

L2_out = L2(L1(L0(input)))
L2_out_detach = L2_out.detach().requires_grad_(True)
output = L5(L4(L3(L2_out_detach)))
loss = loss_fn(output)
compute_grads = should_i_compute_grads_for_l012(output)
loss.backward()    # populates grads for L3, L4, L5, and L2_out_detach
if compute_grads:
    torch.autograd.backward(L2_out, grad_tensors=L2_out_detach.grad)    # populates grads for L0, L1, and L2

This scheme incurs the cost of computing L0, …, L2 (which you have to do). It also incurs the cost of constructing the computation graph for L0, …, L2 whether or not you end up using it. If you usually end up backpropagating through L0, …, L2, this should be okay.
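As a self-contained sketch of this first scheme, here is a runnable version with toy stand-ins (plain Linear layers instead of transformer layers, and a sum() as a stand-in for the loss; the decision function is replaced by a hard-coded choice):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Toy stand-ins (assumption): six Linear layers in place of transformer layers.
L0, L1, L2, L3, L4, L5 = [nn.Linear(4, 4) for _ in range(6)]
x = torch.randn(2, 4)

# Build the L0..L2 graph, but detach before L3 so loss.backward() stops there.
L2_out = L2(L1(L0(x)))
L2_out_detach = L2_out.detach().requires_grad_(True)
output = L5(L4(L3(L2_out_detach)))
loss = output.sum()    # stand-in for loss_fn
loss.backward()        # populates grads for L3..L5 and L2_out_detach only

assert L3.weight.grad is not None
assert L0.weight.grad is None      # backward stopped at the detach point

# Suppose the decision comes back "yes": finish backprop through L0..L2.
torch.autograd.backward(L2_out, grad_tensors=L2_out_detach.grad)
assert L0.weight.grad is not None  # now L0..L2 have grads too
```

The detach() breaks the graph into two halves, so the second backward is optional and can be triggered only when needed.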

If you usually don’t end up computing grads for L0, …, L2, you could:

with torch.no_grad():
    L2_out_no_graph = L2(L1(L0(input)))
L2_out_no_graph.requires_grad_(True)
output = L5(L4(L3(L2_out_no_graph)))
loss = loss_fn(output)
compute_grads = should_i_compute_grads_for_l012(output)
loss.backward()    # populates grads for L3, L4, L5, and L2_out_no_graph
if compute_grads:
    L2_out = L2(L1(L0(input)))    # build the L0, ..., L2 computation graph at the cost of a second forward pass
    torch.autograd.backward(L2_out, grad_tensors=L2_out_no_graph.grad)    # populates grads for L0, L1, and L2


This scheme avoids the cost of constructing the L0, …, L2 computation graph when it isn’t needed (which we are assuming here is the usual case), but when you need that computation graph, you not only incur the cost of building the graph, but also incur the duplicative cost of recomputing L2(L1(L0(input))).
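The second scheme can likewise be exercised end to end with the same toy stand-ins (Linear layers, a sum() loss, and a hard-coded decision in place of the real ones):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Toy stand-ins (assumption): six Linear layers in place of transformer layers.
L0, L1, L2, L3, L4, L5 = [nn.Linear(4, 4) for _ in range(6)]
x = torch.randn(2, 4)

# Run L0..L2 without building a graph; only rebuild it if grads turn out to be needed.
with torch.no_grad():
    L2_out_no_graph = L2(L1(L0(x)))
L2_out_no_graph.requires_grad_(True)   # a leaf: backward will deposit a grad here
output = L5(L4(L3(L2_out_no_graph)))
loss = output.sum()    # stand-in for loss_fn
loss.backward()        # populates grads for L3..L5 and L2_out_no_graph
assert L0.weight.grad is None

# Suppose the decision comes back "yes": pay for a second forward pass,
# this time recording the L0..L2 graph, then backprop through it.
L2_out = L2(L1(L0(x)))
torch.autograd.backward(L2_out, grad_tensors=L2_out_no_graph.grad)
assert L0.weight.grad is not None
```

Because the weights have not changed between the two forward passes, the recomputed L2_out matches the no-graph version, so the grads deposited in L0, …, L2 are the same ones the first scheme would have produced.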

Since you don’t know in advance – that is, before computing output – whether you will need the L0, …, L2 computation graph, you can’t decide in advance which scheme to use. You just have to pick one or the other (in advance) based on the likelihood that you will end up needing the L0, …, L2 computation graph.

Best.

K. Frank