Does a call to something like
loss.backward() cause the CPU code to block?
Does 5 get printed concurrently with loss.backward(), or only after it has finished executing?
I tried running the following example to test it myself:
import torch
from torch import nn

print('net')
net = nn.Linear(100000, 1)
print('loss')
loss = net(torch.ones(10000, 100000)).mean()
print('back')
loss.backward()
print(5)
It’s hard for me to tell. It seems like at least the creation of the net is blocking.
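One way to check this on the CPU is to time the backward() call directly. This is just a minimal sketch (with smaller tensor sizes than the post above so it runs quickly): on CPU, PyTorch operations execute synchronously, so the call only returns once the gradients are fully computed, and the elapsed wall-clock time reflects the actual work.

```python
import time

import torch
from torch import nn

# Smaller sizes than the original snippet so this runs fast.
net = nn.Linear(1000, 1)
loss = net(torch.ones(100, 1000)).mean()

t0 = time.perf_counter()
loss.backward()
elapsed = time.perf_counter() - t0

# On CPU, backward() has fully finished by the time it returns,
# so the gradients are already populated here.
print(f'backward took {elapsed:.6f}s')
print(net.weight.grad is not None)
```

Because the call is synchronous on CPU, `print(5)` in the original snippet only runs after the whole backward pass is done.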
Edit: I would also like to know how this works when the model is on the GPU:
loss = net.to('cuda')(torch.ones(10000, 100000).to('cuda')).mean()
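For the GPU case, a hedged sketch of how one might test it, assuming a CUDA device is available: CUDA kernels are launched asynchronously, so backward() can return to Python before the kernels have finished. Timing it correctly requires torch.cuda.synchronize(); without it, the timer would mostly measure kernel launch overhead rather than completion.

```python
import time

import torch
from torch import nn

# Only meaningful on a machine with a CUDA device.
if torch.cuda.is_available():
    net = nn.Linear(1000, 1).to('cuda')
    loss = net(torch.ones(100, 1000, device='cuda')).mean()

    torch.cuda.synchronize()   # wait for the forward pass to finish
    t0 = time.perf_counter()
    loss.backward()            # queues kernels; may return before they run
    torch.cuda.synchronize()   # block until the backward kernels complete
    print(f'backward took {time.perf_counter() - t0:.6f}s')
```

So on GPU, a subsequent print(5) can run while the backward kernels are still executing on the device; the CPU only blocks if it needs a result (e.g. calling .item() on a CUDA tensor) or synchronizes explicitly.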