Does a call to something like
loss.backward() cause the CPU code to block?
Does 5 get printed concurrently with loss.backward(), or only after it has finished executing?
I tried running the following example to test it myself:
import torch
from torch import nn

print('net')
net = nn.Linear(100000, 1)
print('loss')
loss = net(torch.ones(10000, 100000)).mean()
print('back')
loss.backward()
print(5)
It’s hard for me to tell. It seems like at least the creation of the net is blocking.
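One way to check this on the CPU is to time the backward() call directly. This is just a minimal sketch (with smaller tensor sizes than the post above so it runs quickly): on CPU, PyTorch operations execute synchronously, so the call only returns once the gradients are fully computed, and the elapsed wall-clock time reflects the actual work.

```python
import time

import torch
from torch import nn

# Smaller sizes than the original snippet so this runs fast.
net = nn.Linear(1000, 1)
loss = net(torch.ones(100, 1000)).mean()

t0 = time.perf_counter()
loss.backward()
elapsed = time.perf_counter() - t0

# On CPU, backward() has fully finished by the time it returns,
# so the gradients are already populated here.
print(f'backward took {elapsed:.6f}s')
print(net.weight.grad is not None)
```

Because the call is synchronous on CPU, `print(5)` in the original snippet only runs after the whole backward pass is done.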
Edit: I would also like to know how this works when the model is on the GPU:
loss = net.to('cuda')(torch.ones(10000, 100000).to('cuda')).mean()
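For the GPU case, a hedged sketch of how one might test it, assuming a CUDA device is available: CUDA kernels are launched asynchronously, so backward() can return to Python before the kernels have finished. Timing it correctly requires torch.cuda.synchronize(); without it, the timer would mostly measure kernel launch overhead rather than completion.

```python
import time

import torch
from torch import nn

# Only meaningful on a machine with a CUDA device.
if torch.cuda.is_available():
    net = nn.Linear(1000, 1).to('cuda')
    loss = net(torch.ones(100, 1000, device='cuda')).mean()

    torch.cuda.synchronize()   # wait for the forward pass to finish
    t0 = time.perf_counter()
    loss.backward()            # queues kernels; may return before they run
    torch.cuda.synchronize()   # block until the backward kernels complete
    print(f'backward took {time.perf_counter() - t0:.6f}s')
```

So on GPU, a subsequent print(5) can run while the backward kernels are still executing on the device; the CPU only blocks if it needs a result (e.g. calling .item() on a CUDA tensor) or synchronizes explicitly.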