Freezing layers when using torch.compile


I am fine-tuning Wav2Vec by placing an MLP on top of its representations: I freeze Wav2Vec for the first few epochs of training and then unfreeze it, which I implement by setting requires_grad=False on all Wav2Vec parameters. I would also like to use torch.compile to speed up training. When I try this, after unfreezing the Wav2Vec module (i.e. setting requires_grad=True), I get the following error when I pass inputs into my model:

RuntimeError: addmm(): functions with out=… arguments don’t support automatic differentiation, but one of the arguments requires grad.

I am pretty confused about this, but my guess is that torch.compile built the computational graph while the Wav2Vec parameters didn't require gradients, and it is hence unable to handle the epochs after I unfreeze the model. My question is how I should go about handling this - should I compile the model twice (e.g. once for the training epochs where the Wav2Vec layer is frozen, and then re-compile when I unfreeze it), or am I missing something?

I am sorry if this is a silly question, but I am really struggling to figure it out.

Thanks in advance if anyone can point me towards anything that helps.

That sounds like a bug with torch.compile.

(1) what version of PyTorch are you using? Can you try running your example with a nightly and see if you still get the same error? (There have generally been a lot of bug fixes over the last few months).

(2) if you still get the error on a nightly, do you have a reproducer script that shows the issue? That would be very helpful - you can file an issue with the repro on GitHub.

Yep, upgrading PyTorch fixed it - I was on 2.1, and the error goes away on 2.2. Thanks!