I have the following code snippet:
```python
import torch.nn as nn

class Model(nn.Module):
    # omitting __init__

    def expensive_call(self, x):
        # omitting detailed definition
        # returns a scalar that backward can be run on
        return some_scalar

    def forward(self, x):
        self.expensive_call(x).backward()

m = nn.DataParallel(Model().cuda())
# Is this thread-safe?
m(x)  # calling forward on input x, expecting a parallel backward
```
The motivation: I don’t care about the return value of the forward call, and I specifically don’t want to gather the results back to device 0 and only then run backward. So I wrap the detailed logic into `expensive_call` and let `DataParallel` run backward inside forward for me (in a parallel fashion).
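To make the pattern concrete, here is a minimal runnable sketch of what I mean. The `nn.Linear` and `.sum()` inside `expensive_call` are just placeholders for my real (much more expensive) logic, and the script falls back to CPU when no GPU is available:

```python
import torch
import torch.nn as nn

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(8, 8)  # placeholder for the real submodules

    def expensive_call(self, x):
        # placeholder for the real expensive logic; returns a scalar loss
        return self.layer(x).sum()

    def forward(self, x):
        # run backward inside forward, so each replica backprops locally
        self.expensive_call(x).backward()

device = "cuda" if torch.cuda.is_available() else "cpu"
m = nn.DataParallel(Model().to(device))
x = torch.randn(4, 8, device=device)
m(x)  # forward triggers backward; no result is gathered back
```

On a single CPU device this clearly populates `m.module.layer.weight.grad`; my question is whether the same holds, safely, when `DataParallel` splits the batch across several GPUs in separate threads.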
Despite the uncommon pattern, does this code actually work, and is it bug-free? I roughly know that GPU kernels are launched onto CUDA streams, and that there is some multithreading code under the hood of `DataParallel`. Are these CUDA streams thread-safe?