After some digging I found this thread:
which describes exactly the same issue. I used the solution in my code and training works properly now! Although I'd argue this is more of a bug than a misuse, since torch.no_grad()
can appear anywhere inside the model, while we typically wrap the entire forward pass at the outermost layer. Also, I was using Hugging Face's accelerate launch tool, which applies mixed precision training by "preparing" the entire model, so it took me a long time to pinpoint the torch.no_grad()
block that was buried somewhere inside my model.
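Since the thread and its exact fix aren't quoted here, below is only a minimal sketch of the pattern described: a torch.no_grad() block buried deep inside forward() while mixed precision is applied at the outermost call (roughly what accelerate's prepared model does). The model and names (ToyModel, encoder, head) are illustrative, and CPU bfloat16 autocast stands in for GPU mixed precision so the sketch runs anywhere:

```python
import torch
from torch import nn

class ToyModel(nn.Module):
    """Illustrative model: a torch.no_grad() block hidden inside
    forward(), far from the outer autocast wrapper."""
    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(8, 8)
        self.head = nn.Linear(8, 2)

    def forward(self, x):
        # Deep inside the model, a sub-computation is wrapped in
        # no_grad -- invisible from the call site.
        with torch.no_grad():
            frozen = self.encoder(x)
        return self.head(frozen)

model = ToyModel()
x = torch.randn(4, 8)

# Mixed precision is applied at the outermost layer, mimicking an
# accelerate-prepared forward pass.
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    out = model(x)

out.float().sum().backward()
# The buried no_grad cut the graph: the encoder never receives a
# gradient, while the head trains normally -- hard to spot from outside.
print(model.encoder.weight.grad)            # None
print(model.head.weight.grad is not None)   # True
```

This is why the interaction is so hard to debug: nothing errors out at the outer autocast layer, and you only notice once some parameters silently stop updating.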