Hi there,
I would like to use `torch.nn.DataParallel` for intermediate layers in my model. The structure looks like the following:
```python
import torch
import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        ...  # <some other layers>
        self.linear = nn.DataParallel(nn.Linear(100, 100))
        self.linear2 = nn.DataParallel(nn.Linear(100, 500))

    def forward(self, x):
        x = x.to("cuda:0")  # assume moved to gpu-0
        ...  # <some other operations on the input>
        x = self.linear(x)
        x = self.linear2(x)
        ...
        return x
```
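For concreteness, this is roughly how I use it (a minimal sketch on my side; the batch size, target shape, and MSE loss are made up):

```python
model = MyModel().to("cuda:0")  # DataParallel wants its params on the first device
x = torch.randn(32, 100)
target = torch.randn(32, 500, device="cuda:0")

out = model(x)                              # gathered on cuda:0 by DataParallel
loss = nn.functional.mse_loss(out, target)  # loss value lives on cuda:0
loss.backward()
```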
However, I have several concerns:
- The output of `nn.DataParallel` is by default gathered on gpu-0. Will this cause extra overhead because of the memory copies among the different GPUs? Assume we have 8 GPUs for data parallelism. (See the first sketch after this list.)
- When we calculate the loss, if the loss value is computed on GPU-0 (as in the training sketch above), how does the intermediate `nn.DataParallel` layer do the right backward pass?
- Does PyTorch allow a cross-device backward pass? E.g., I have a tensor `A` at `cuda:0`, I keep a reference to it and move it to `cuda:1` (`B = A; A = A.to('cuda:1')`). When I later call backward on the variable `A`, will the gradient computation propagate to `B` at `cuda:0`? (See the second sketch after this list.)
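To make the first concern concrete, here is a tiny experiment for checking where the gathered output lives (just a sketch; it assumes at least 2 visible GPUs):

```python
import torch
import torch.nn as nn

layer = nn.DataParallel(nn.Linear(100, 100)).to("cuda:0")
x = torch.randn(32, 100, device="cuda:0")
out = layer(x)
print(out.device)  # cuda:0 -- inputs are scattered and outputs gathered on every call
```

If every wrapped layer scatters and gathers like this, I worry the copies add up when several such layers are chained.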
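And a minimal version of the third, cross-device question (again a sketch; it assumes 2 GPUs, and I'm not sure whether `.to()` is recorded by autograd or detaches):

```python
import torch

B = torch.randn(4, 4, device="cuda:0", requires_grad=True)  # original tensor on gpu-0
A = B.to("cuda:1")   # copy moved to gpu-1 -- is this copy differentiable?
A.sum().backward()   # backward starting from the cuda:1 tensor
print(B.grad)        # does a gradient show up here on cuda:0?
```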