I have the following code snippet:
```python
import torch.nn as nn

class Model(nn.Module):
    # omitting __init__

    def expensive_call(self, x):
        # omitting detailed definition
        # returns a scalar that backward can be run on
        return some_scalar

    def forward(self, x):
        self.expensive_call(x).backward()

m = nn.DataParallel(Model().cuda())
# Is this thread-safe?
m(x)  # calling forward on input x, expecting a parallel backward
```
The motivation: I don’t care about the return value of the forward call, and I specifically don’t want to gather the results back to device 0 and only then run backward. So I wrap the detailed logic into `expensive_call` and let `DataParallel` run backward inside forward for me (in a parallel fashion).
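To make the pattern concrete, here is a minimal runnable sketch of what I mean. The `nn.Linear` and `.sum()` inside `expensive_call` are just placeholders for my real (much more expensive) logic, and the script falls back to CPU when no GPU is available:

```python
import torch
import torch.nn as nn

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(8, 8)  # placeholder for the real submodules

    def expensive_call(self, x):
        # placeholder for the real expensive logic; returns a scalar loss
        return self.layer(x).sum()

    def forward(self, x):
        # run backward inside forward, so each replica backprops locally
        self.expensive_call(x).backward()

device = "cuda" if torch.cuda.is_available() else "cpu"
m = nn.DataParallel(Model().to(device))
x = torch.randn(4, 8, device=device)
m(x)  # forward triggers backward; no result is gathered back
```

On a single CPU device this clearly populates `m.module.layer.weight.grad`; my question is whether the same holds, safely, when `DataParallel` splits the batch across several GPUs in separate threads.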
Despite the uncommon pattern, does this code actually work, and is it bug-free? I roughly know that GPU kernels are launched onto CUDA streams, and that there is some multithreading code under the hood of `DataParallel`. Are these CUDA streams thread-safe?