Output switches to CPU after passing though ModuleList blocks

I am training a model using DistributedDataParallel(code snippet below). My model is initialized using nn.ModuleList(). But once the input which is on GPU passes through one of the blocks in nn.ModuleList, it switches to CPU mode.

if use_cuda and torch.cuda.device_count()>1:
    model = model.to(rank)
    model = DistributedDataParallel(model, device_ids=[rank])

Please refer to forward method of Block class which creates a list of SimpleNet class. Please let me know if you require additional piece of code.

class SimpleNet(nn.Module):
  def __init__(self, inp, parity):
    super(SimpleNet, self).__init__()
    self.net = nn.Sequential(
      nn.Linear(inp//2, 256),
      nn.Linear(256, 256),
      nn.Linear(256, inp//2),
    self.inp = inp
    self.parity = parity

  def forward(self, x):
    z = torch.zeros(x.size())
    x0, x1 = x[:, ::2], x[:, 1::2]
    if self.parity % 2:
      x0, x1 = x1, x0 
    # print("X: ", x0[0][0].detach(), x1[0][0].detach())
    z1 = x1
    log_s = self.net(x1)
    # print(x.size(), x1.size(), log_s.size())
    t = self.net(x1)
    s = torch.exp(log_s)
    z0 = (s * x0) + t
    # print("Z: ", z0[0][0].detach(), z1[0][0].detach())
    if self.parity%2:
      z0, z1 = z1, z0
    z[:, ::2] = z0
    z[:, 1::2] = z1
    logdet = torch.sum(torch.log(s), dim = 1)
    return z, logdet
  def reverse(self, z):
    x = torch.zeros(z.size())
    z0, z1 = z[:, ::2], z[:, 1::2]
    if self.parity%2:
      z0, z1 = z1, z0
    # print("Z: ", z0[0][0].detach(), z1[0][0].detach())
    x1 = z1
    log_s = self.net(z1)
    t = self.net(z1)
    s = torch.exp(log_s)
    x0 = (z0 - t)/s
    # print("X: ", x0[0][0].detach(), x1[0][0].detach())
    if self.parity%2:
      x0, x1 = x1, x0
    x[:, ::2] = x0
    x[:, 1::2] = x1
    return x

class Block(nn.Module):
  def __init__(self, inp, n_blocks):
    super(Block, self).__init__()
    parity = 0
    blocks = []
    for _ in range(n_blocks):
      blocks.append(SimpleNet(inp, parity))
      parity += 1
    self.blocks = nn.ModuleList(blocks) 
  def forward(self, x):
    logdet = 0
    out = x
    xs = [out]
    # print("*"*20, "FORWARD", "*"*30)
    for block in self.blocks:
      print("device in block: ", out.is_cuda) # True
      out, det = block(out)
      print("device in block: ", out.is_cuda) # False
      logdet += det
    return out, logdet

  def reverse(self, z):
    # print("*"*20, "REVERSE", "*"*30)
    out = z
    for block in self.blocks[::-1]:
      out = block.reverse(out)
    return out

In the SimpleNet forward() function, the zeros tensor z is being created on CPU, which is why the out tensor returned by the forward function of that block returns out.is_cuda = False. You must either place the z tensor explicitly on the correct rank or make it an nn.Parameter or similar type that will be moved to GPU when the entire module is placed on GPU.

Thank you for your response. I was able to make it work using the nn.Parameter and register_parameter but I noticed that if I just initialize the z tensor like z = torch.zeros_like(x) instead of torch.zeros(x.size()), it is automatically loaded to the same device as x.

1 Like