Determine communication backend

Hey everyone,

I have a model spread across a couple of GPUs:

class MicroUNet3D(nn.Module):

    def __init__(self, n_channels, n_classes):
        super(MicroUNet3D, self).__init__()
        self.inconv = InConv(n_channels, 2).to('cuda:0')
        self.down1 = Down(2, 4).to('cuda:0')
        self.down2 = Down(4, 8).to('cuda:0')
        self.up1 = Up(8, 4).to('cuda:1')
        self.up2 = Up(4, 2).to('cuda:1')
        self.outconv = OutConv(2, n_classes).to('cuda:1')

    def forward(self, x):
        x1 = self.inconv(x)
        x2, indices1 = self.down1(x1)
        x3, indices2 = self.down2(x2)

        # Transfer to next GPU.
        x2, indices1 ='cuda:1'),'cuda:1')
        x3, indices2 ='cuda:1'),'cuda:1')

        x4 = self.up1(x3, indices2, x2.shape)
        x5 = self.up2(x4, indices1, x1.shape)
        x6 = self.outconv(x5)
        return x6

Is there a way to determine how the communication is handled by the to() method? I am hoping that PyTorch will use NCCL here, and I would like to make sure.

No, PyTorch does not use NCCL for to() (copying from one GPU to another). That's not one of the operations provided by NCCL.

PyTorch does use NCCL as a distributed backend and for DataParallel broadcast and reduction.
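For reference, here is a minimal sketch of selecting NCCL as the distributed backend (single process, world size 1, with gloo as a CPU-only fallback; the address and port values are placeholders for a real launcher):

```python
import os
import torch
import torch.distributed as dist

# Choose NCCL when GPUs are present; gloo is the CPU-only fallback.
backend = "nccl" if torch.cuda.is_available() else "gloo"

# Single-process sketch (world size 1); real jobs launch one process per GPU.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group(backend, rank=0, world_size=1)

print(dist.get_backend())
dist.destroy_process_group()
```

With more than one process, collectives such as dist.all_reduce would then go over NCCL on the GPU ranks.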

Ah yeah, for some reason I was thinking that NCCL could do send/recv. Does that mean that to() is coming down to main memory to then transfer over to the next GPU? Could it use NVLink?

Yes, it will use NVLink if available. The choice of how to communicate is made by the CUDA driver. PyTorch just calls cudaMemcpy (or launches a P2P kernel in some cases).
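To see what the driver can do on a given machine, you can query peer access directly before relying on fast device-to-device copies (a small sketch; the helper name is my own):

```python
import torch

def gpu_copy_path(src: int = 0, dst: int = 1) -> str:
    """Describe how a .to() copy between two GPUs is likely to travel.
    If peer access is possible, the driver can copy directly over
    NVLink or PCIe P2P; otherwise the copy is staged through host RAM."""
    if torch.cuda.device_count() <= max(src, dst):
        return "fewer than two GPUs visible"
    if torch.cuda.can_device_access_peer(src, dst):
        return f"cuda:{src} -> cuda:{dst}: direct peer-to-peer copy"
    return f"cuda:{src} -> cuda:{dst}: staged through host memory"

print(gpu_copy_path())
```

On a box with NVLink-connected GPUs this should report the direct path; nvidia-smi topo -m shows the same topology from the driver's point of view.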
