I am training a model using DistributedDataParallel(code snippet below). My model is initialized using nn.ModuleList(). But once the input which is on GPU passes through one of the blocks in nn.ModuleList, it switches to CPU mode.
if use_cuda and torch.cuda.device_count()>1:
model = model.to(rank)
model = DistributedDataParallel(model, device_ids=[rank])
Please refer to forward method of Block class which creates a list of SimpleNet class. Please let me know if you require additional piece of code.
class SimpleNet(nn.Module):
def __init__(self, inp, parity):
super(SimpleNet, self).__init__()
self.net = nn.Sequential(
nn.Linear(inp//2, 256),
nn.LeakyReLU(True),
nn.Linear(256, 256),
nn.LeakyReLU(True),
nn.Linear(256, inp//2),
nn.Sigmoid(),
nn.BatchNorm1d(392)
)
self.inp = inp
self.parity = parity
def forward(self, x):
z = torch.zeros(x.size())
x0, x1 = x[:, ::2], x[:, 1::2]
if self.parity % 2:
x0, x1 = x1, x0
# print("X: ", x0[0][0].detach(), x1[0][0].detach())
z1 = x1
log_s = self.net(x1)
# print(x.size(), x1.size(), log_s.size())
t = self.net(x1)
s = torch.exp(log_s)
z0 = (s * x0) + t
# print("Z: ", z0[0][0].detach(), z1[0][0].detach())
if self.parity%2:
z0, z1 = z1, z0
z[:, ::2] = z0
z[:, 1::2] = z1
logdet = torch.sum(torch.log(s), dim = 1)
return z, logdet
def reverse(self, z):
x = torch.zeros(z.size())
z0, z1 = z[:, ::2], z[:, 1::2]
if self.parity%2:
z0, z1 = z1, z0
# print("Z: ", z0[0][0].detach(), z1[0][0].detach())
x1 = z1
log_s = self.net(z1)
t = self.net(z1)
s = torch.exp(log_s)
x0 = (z0 - t)/s
# print("X: ", x0[0][0].detach(), x1[0][0].detach())
if self.parity%2:
x0, x1 = x1, x0
x[:, ::2] = x0
x[:, 1::2] = x1
return x
class Block(nn.Module):
def __init__(self, inp, n_blocks):
super(Block, self).__init__()
parity = 0
blocks = []
for _ in range(n_blocks):
blocks.append(SimpleNet(inp, parity))
parity += 1
self.blocks = nn.ModuleList(blocks)
def forward(self, x):
logdet = 0
out = x
xs = [out]
# print("*"*20, "FORWARD", "*"*30)
for block in self.blocks:
print("device in block: ", out.is_cuda) # True
out, det = block(out)
print("device in block: ", out.is_cuda) # False
logdet += det
xs.append(out)
return out, logdet
def reverse(self, z):
# print("*"*20, "REVERSE", "*"*30)
out = z
for block in self.blocks[::-1]:
out = block.reverse(out)
return out