Can gradients be computed on tensors obtained from a forward hook?

Hi, I’m trying to see if I can finetune a pretrained vision model for a downstream task using intermediate activations as features. Since it’s cumbersome to manually save activations by editing model definitions to output multiple tensors, I was trying to do this by just adding hooks to save activations from relevant layers.

As a litmus test, I wanted to see if I could train a model by grabbing the last layer’s output from a hook rather than directly from the model’s output:

def save_act(name, mod, inp, out):
    activations[name].append(out)

and using the saved output from the linear layer to try and compute the gradients:

for name, m in model.named_modules():
    if type(m) == nn.Linear:
        m.register_forward_hook(partial(save_act, name))

output = activations['module.fc'][0]
loss = criterion(output, target)
loss.backward()

When I first tried this, I got:

RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.

But after adding retain_graph=True to loss.backward(), I get

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [512, 196]], which is output 0 of TBackward, is at version 2; expected version 1 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).

which makes it seem like the saved tensor is somehow stale or different from the one I should be using. Is there another way to compute gradients on tensors obtained like this, or a fix for this approach?


Could you try to assign the out tensor to activations[name] instead of appending it?
If I understand your use case correctly, you would only need the current output, not all outputs from previous iterations (which might cause this error, if their computation graph was already freed).
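For reference, here is a minimal, self-contained sketch of the assign-instead-of-append approach. The model, layer names, and shapes below are made up for illustration, not taken from your code:

```python
import torch
import torch.nn as nn
from functools import partial

# Hypothetical toy model; in nn.Sequential the Linear layer gets the name '1'.
model = nn.Sequential(nn.Flatten(), nn.Linear(8, 4))
activations = {}

def save_act(name, mod, inp, out):
    # Assign the current output instead of appending to a list,
    # so each iteration uses a tensor from the live graph.
    activations[name] = out

for name, m in model.named_modules():
    if isinstance(m, nn.Linear):
        m.register_forward_hook(partial(save_act, name))

criterion = nn.CrossEntropyLoss()
x = torch.randn(2, 2, 4)
target = torch.tensor([0, 1])

# Multiple iterations now work without retain_graph=True,
# because each backward uses the current forward pass's graph.
for _ in range(3):
    model(x)
    loss = criterion(activations['1'], target)
    model.zero_grad()
    loss.backward()
```

Since the hook's `out` tensor is part of the current computation graph, calling `backward()` on a loss built from it populates gradients on the model's parameters as usual.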


Yes, you were right: it turns out I was reusing the first tensor in the list over and over, which was the problem and not even my actual intent. Thanks!