Is this code snippet thread-safe with DataParallel?

I have the following code snippet:


import torch
import torch.nn as nn

class Model(nn.Module):
  # omitting __init__

  # returns a scalar loss that backward() can be called on
  def expensive_call(self, x):
    # omitting detailed definition
    return some_scalar

  def forward(self, x):
    # backward() is called inside forward, once per replica
    self.expensive_call(x).backward()

m = nn.DataParallel(Model().cuda())

# Is this thread-safe?
m(x)  # calling forward on input x, expecting a parallel backward

The motivation: I don't care about the return value of the forward call, and I specifically don't want to gather the results back to device 0 and then run backward there. So I wrap the detailed logic in expensive_call and let DataParallel run backward inside forward for me, in parallel on each replica.
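For concreteness, here is a minimal self-contained version of the pattern; the linear layer and squared-sum loss are just placeholders standing in for my real model, not the actual code:

import torch
import torch.nn as nn

class ToyModel(nn.Module):
  def __init__(self):
    super().__init__()
    self.linear = nn.Linear(10, 1)  # placeholder for the real layers

  def expensive_call(self, x):
    # placeholder loss; sum reduction so per-replica gradients
    # add up to the same total as a single-device run
    return self.linear(x).pow(2).sum()

  def forward(self, x):
    loss = self.expensive_call(x)
    loss.backward()                    # backward inside forward, per replica
    return loss.detach().unsqueeze(0)  # gathered only for logging

m = nn.DataParallel(ToyModel().cuda())
x = torch.randn(8, 10).cuda()
m(x)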

Despite the uncommon pattern, does this code actually work, i.e., is it bug-free? I roughly know that GPU work is enqueued onto CUDA streams, and that DataParallel runs some multithreading code under the hood. Are these CUDA streams thread-safe?
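In case it's useful, this is the sanity check I would run: compare the gradients from the backward-inside-forward pattern against a plain single-device run (continuing from the ToyModel sketch above; the tolerance is arbitrary):

# Continuing from the ToyModel sketch above.
reference = ToyModel().cuda()
parallel = nn.DataParallel(ToyModel().cuda())
parallel.module.load_state_dict(reference.state_dict())  # identical weights

x = torch.randn(8, 10).cuda()

reference.expensive_call(x).backward()  # plain single-device gradients
parallel(x)                             # backward-inside-forward pattern

# With a sum-reduced loss the per-replica gradients should add up to the
# single-device gradients; a mean-reduced loss would need rescaling here.
for p_ref, p_dp in zip(reference.parameters(), parallel.module.parameters()):
  print(torch.allclose(p_ref.grad, p_dp.grad, atol=1e-5))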