DataParallel for torch.distributions (MultivariateNormal)

I am using the DataParallel module to train my network. However, I get a GPU device error when the MultivariateNormal distribution is used with my tensors.

# Using DataParallel
prior = MultivariateNormal(torch.zeros(dim).to(device), torch.eye(dim).to(device))
model = Flow(dim, prior, n_block)
model = model.to(device)
if use_cuda and torch.cuda.device_count()>1:
     model = DataParallel(model, device_ids=[0, 1, 2, 3])

# Calculating log probability of my Multivariate distribution
logprob = prior.log_prob(z).view(x.size(0), -1).sum(1)
# Error
File "/data/saandeepaath/flow_based/modules/flows.py", line 95, in forward
logprob = self.prior.log_prob(z).view(x.size(0), -1).sum(1)
File "/home/saandeepaath/.local/lib/python3.7/site-packages/torch/distributions/multivariate_normal.py", line 207, in log_prob
    diff = value - self.loc
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:1 and cuda:0!

Your code snippet is unfortunately not executable and doesn’t use the model at all.
Could you add the missing parts so that we could reproduce this issue and debug it, please?

Certainly, and sorry for leaving out the details. Here is my full code. One strange thing: when I ran my code just now to make sure I was posting the right version, it worked without issues. I am certain the same code threw this error on the CUDA devices the last time I ran it.

The classes below define my model. The Flow() class creates the entire model.

import torch 
from torch import nn

from sys import exit as e

class SimpleNet(nn.Module):
  def __init__(self, inp, parity):
    super(SimpleNet, self).__init__()
    self.net = nn.Sequential(
      nn.Linear(inp//2, 256),
      nn.LeakyReLU(inplace=True),  # a bare True would set negative_slope=1 (identity)
      nn.Linear(256, 256),
      nn.LeakyReLU(inplace=True),
      nn.Linear(256, inp//2),
      nn.Sigmoid()
    )
    self.inp = inp
    self.parity = parity
  
  def forward(self, x):
    z = torch.zeros_like(x)
    x0, x1 = x[:, :, ::2, ::2], x[:, :, 1::2, 1::2]
    if self.parity % 2:
      x0, x1 = x1, x0 
    z1 = x1
    log_s = self.net(x1)
    t = self.net(x1)
    s = torch.exp(log_s)
    z0 = (s * x0) + t
    if self.parity%2:
      z0, z1 = z1, z0
    z[:, :, ::2, ::2] = z0
    z[:, :, 1::2, 1::2] = z1
    logdet = torch.sum(torch.log(s), dim = 1)
    return z, logdet
  
  def reverse(self, z):
    x = torch.zeros_like(z)
    z0, z1 = z[:, :, ::2, ::2], z[:, :, 1::2, 1::2]
    if self.parity%2:
      z0, z1 = z1, z0
    x1 = z1
    log_s = self.net(z1)
    t = self.net(z1)
    s = torch.exp(log_s)
    x0 = (z0 - t)/s
    if self.parity%2:
      x0, x1 = x1, x0
    x[:, :, ::2, ::2] = x0
    x[:, :, 1::2, 1::2] = x1
    return x
  

class Block(nn.Module):
  def __init__(self, inp, n_blocks):
    super(Block, self).__init__()
    parity = 0
    self.blocks = nn.ModuleList()
    for _ in range(n_blocks):
      self.blocks.append(SimpleNet(inp, parity))
      parity += 1
    
  
  def forward(self, x):
    logdet = 0
    out = x
    xs = [out]
    for block in self.blocks:
      out, det = block(out)
      logdet += det
      xs.append(out)
    return out, logdet

  def reverse(self, z):
    out = z
    for block in self.blocks[::-1]:
      out = block.reverse(out)
    return out


class Flow(nn.Module):
  def __init__(self, inp, prior, n_blocks):
    super(Flow, self).__init__()
    self.prior = prior
    self.flow = Block(inp, n_blocks)
  
  def forward(self, x):
    z, logdet = self.flow(x)
    logprob = self.prior.log_prob(z).view(x.size(0), -1).sum(1) #Error encountered here
    return z, logdet, logprob
  
  def reverse(self, z):
    x = self.flow.reverse(z)
    return x
  
  def get_sample(self, n):
    z = self.prior.sample(sample_shape = torch.Size([n]))
    return self.reverse(z)

I define my prior, instantiate the model as follows, and train it on the MNIST dataset:

def startup(opt, use_cuda):
  torch.manual_seed(1)
  device = "cuda" if not opt.no_cuda and torch.cuda.is_available() else "cpu"
  kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}

  inpt = 100
  dim = 1
  img_size = 28
  n_block = 4
  epochs = 50
  lr = 0.01
  wd=1e-3
  old_loss = 1e6
  best_loss = 0
  batch_size = 128
  prior = MultivariateNormal(torch.zeros(img_size).to(device), torch.eye(img_size).to(device))

  #MNIST
  transform = transforms.Compose([transforms.ToTensor()])
  train_dataset = MNIST(root=opt.root, train=True, transform=transform, \
    download=True)
  train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, **kwargs)
  
  model = Flow(dim, prior, n_block)
  model = model.to(device)

  if use_cuda and torch.cuda.device_count()>1:
        model = DataParallel(model, device_ids=[0, 1, 2, 3])

  optimizer = optim.Adam(model.parameters(), lr)
  scheduler = StepLR(optimizer, step_size=1, gamma=0.7)
  
  t0 = time.time()
  for epoch in range(epochs):
    model.train()
    train_data(opt, model, device, train_loader, optimizer, epoch)
    scheduler.step()
  print(f"time to complete {epochs} epoch: {time.time() - t0} seconds")

Training is done the usual way, with some preprocessing:

def preprocess(x):
  x = x * 255
  x = torch.floor(x/2**3)
  x = x/32 - 0.5
  return x

for b, (x, _) in enumerate(train_loader):
    optimizer.zero_grad()
    x = x.to(device)
    x = preprocess(x)
    x = x.view(x.size(0), -1)
    z, logdet, logprob = model.module(x) # Encounters the original error

(I have skipped the rest of the steps as the program stops here)

That makes it practically impossible to debug. Let me know if you can come up with a code snippet that reproduces this issue.

I identified the root cause. The main difference was how I was calling the forward method. I simply changed from

z, logdet, logprob = model.forward(x)

to

z, logdet, logprob = model.module.forward(x)

Below is my complete code to reproduce the issue. The second line of Flow.forward (the log_prob call) raises the error. I am still looking into why calling forward through module resolves the issue, but it would be great if you could help me understand it, as working with MultivariateNormal on GPUs gets a little confusing for me. Thank you so much.

ERROR

File "/flows.py", line 91, in forward
    logprob = self.prior.log_prob(z).view(x.size(0), -1).sum(1) #Error encountered here
  File "/home/saandeepaath/.local/lib/python3.7/site-packages/torch/distributions/multivariate_normal.py", line 207, in log_prob
    diff = value - self.loc
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:1 and cuda:0!
1. flows.py (model classes)
import torch 
from torch import nn

from sys import exit as e

class SimpleNet(nn.Module):
  def __init__(self, inp, parity):
    super(SimpleNet, self).__init__()
    self.net = nn.Sequential(
      nn.Conv2d(inp, 32, 3, 1, 1),
      nn.Tanh(),
      nn.Conv2d(32, 64, 3, 1, 1),
      nn.Tanh(),
      nn.Conv2d(64, 32, 3, 1, 1),
      nn.Tanh(),
      nn.Conv2d(32, inp, 3, 1, 1),
      nn.Tanh(),
    )
    self.inp = inp
    self.parity = parity
  
  def forward(self, x):
    z = torch.zeros_like(x)
    x0, x1 = x[:, :, ::2, ::2], x[:, :, 1::2, 1::2]
    if self.parity % 2:
      x0, x1 = x1, x0 
    z1 = x1
    log_s = self.net(x1)
    t = self.net(x1)
    s = torch.exp(log_s)
    z0 = (s * x0) + t
    if self.parity%2:
      z0, z1 = z1, z0
    z[:, :, ::2, ::2] = z0
    z[:, :, 1::2, 1::2] = z1
    logdet = torch.sum(torch.log(s), dim = 1)
    return z, logdet
  
  def reverse(self, z):
    x = torch.zeros_like(z)
    z0, z1 = z[:, :, ::2, ::2], z[:, :, 1::2, 1::2]
    if self.parity%2:
      z0, z1 = z1, z0
    x1 = z1
    log_s = self.net(z1)
    t = self.net(z1)
    s = torch.exp(log_s)
    x0 = (z0 - t)/s
    if self.parity%2:
      x0, x1 = x1, x0
    x[:, :, ::2, ::2] = x0
    x[:, :, 1::2, 1::2] = x1
    return x
  

class Block(nn.Module):
  def __init__(self, inp, n_blocks):
    super(Block, self).__init__()
    parity = 0
    self.blocks = nn.ModuleList()
    for _ in range(n_blocks):
      self.blocks.append(SimpleNet(inp, parity))
      parity += 1
    
  
  def forward(self, x):
    logdet = 0
    out = x
    xs = [out]
    for block in self.blocks:
      out, det = block(out)
      logdet += det
      xs.append(out)
    return out, logdet

  def reverse(self, z):
    out = z
    for block in self.blocks[::-1]:
      out = block.reverse(out)
    return out


class Flow(nn.Module):
  def __init__(self, inp, prior, n_blocks):
    super(Flow, self).__init__()
    self.prior = prior
    self.flow = Block(inp, n_blocks)
  
  def forward(self, x):
    z, logdet = self.flow(x)
    logprob = self.prior.log_prob(z).view(x.size(0), -1).sum(1) #Error encountered here
    return z, logdet, logprob
  
  def reverse(self, z):
    x = self.flow.reverse(z)
    return x
  
  def get_sample(self, n):
    z = self.prior.sample(sample_shape = torch.Size([n]))
    return self.reverse(z)
2. Define the multivariate distribution and instantiate the model
import time
import torch 
from torch import optim
from torch.distributions import MultivariateNormal
from torchvision.datasets import MNIST
from torchvision import transforms
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import StepLR
from torch.nn.parallel import  DataParallel

def startup(opt, use_cuda):
  torch.manual_seed(1)
  device = "cuda" if not opt.no_cuda and torch.cuda.is_available() else "cpu"
  kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}

  inpt = 100
  dim = 1
  img_size = 28
  n_block = 9
  epochs = 5
  lr = 0.01
  wd=1e-3
  old_loss = 1e6
  best_loss = 0
  batch_size = 128
  prior = MultivariateNormal(torch.zeros(img_size).to(device), torch.eye(img_size).to(device))

  #MNIST
  transform = transforms.Compose([transforms.ToTensor()])
  train_dataset = MNIST(root=opt.root, train=True, transform=transform, \
    download=True)
  train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, **kwargs)
  
  model = Flow(dim, prior, n_block)
  model = model.to(device)

  if use_cuda and torch.cuda.device_count()>1:
    model = DataParallel(model, device_ids=[0, 1, 2, 3])

  optimizer = optim.Adam(model.parameters(), lr)
  scheduler = StepLR(optimizer, step_size=1, gamma=0.7)
  
  t0 = time.time()
  for epoch in range(epochs):
    model.train()
    train_data(opt, model, device, train_loader, optimizer, epoch)
    scheduler.step()
  print(f"time to complete {epochs} epoch: {time.time() - t0} seconds")

3. Training step
def preprocess(x):
  x = x * 255
  x = torch.floor(x/2**3)
  x = x/32 - 0.5
  return x

def train_data(opt, model, device, train_loader, optimizer, epoch):
  for b, (x, _) in enumerate(train_loader):
    optimizer.zero_grad()
    x = x.to(device)
    x = preprocess(x)
    z, logdet, logprob = model.forward(x) # Causes error. Change to model.module.forward(x) to resolve 

(the rest of the step is skipped, since the line above is where my code throws the error)
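To make the log_prob line in Flow.forward easier to follow, here is a shape sketch. It assumes z keeps the MNIST batch shape [B, 1, 28, 28] (the coupling layers build their output with torch.zeros_like(x), so the shape is preserved) and uses a random tensor as a stand-in for the flow output. Because the prior is a 28-dimensional MultivariateNormal, log_prob treats each image row as one sample; with an identity covariance, summing the row log-probabilities is the same as scoring all 784 pixels as independent standard normals.

```python
import torch
from torch.distributions import MultivariateNormal

torch.manual_seed(0)
B, img_size = 4, 28

# Same prior as in startup(): standard normal over 28-dimensional vectors.
prior = MultivariateNormal(torch.zeros(img_size), torch.eye(img_size))

# Stand-in for the flow output z (assumed shape: one MNIST batch).
z = torch.randn(B, 1, img_size, img_size)

# The event dimension is the last axis, so each image row is scored separately.
lp = prior.log_prob(z)
print(lp.shape)        # torch.Size([4, 1, 28])

# Flatten the per-row values and sum them into one log-probability per image.
logprob = lp.view(z.size(0), -1).sum(1)
print(logprob.shape)   # torch.Size([4])
```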

Thanks for the update.
The workaround is invalid: model.module is the underlying model, so calling it directly runs the forward pass on the default device only and skips the nn.DataParallel wrapper (no scattering of the batch, no replicas).
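A small sketch of that difference, using a hypothetical toy module (ShapeLogger is not from this thread) that prints the device and batch size each forward call receives. On a multi-GPU machine, calling the wrapper should print one line per replica, each with a slice of the batch, while calling .module prints a single line with the full batch on the default device. On a CPU-only machine, DataParallel simply forwards to the wrapped module, so both calls behave the same.

```python
import torch
import torch.nn as nn

class ShapeLogger(nn.Module):
    """Hypothetical toy model: reports where each forward call runs."""

    def __init__(self):
        super().__init__()
        self.lin = nn.Linear(4, 2)

    def forward(self, x):
        print(f"forward on {x.device} with batch size {x.size(0)}")
        return self.lin(x)

device = "cuda" if torch.cuda.is_available() else "cpu"
model = ShapeLogger().to(device)
dp = nn.DataParallel(model)  # uses all visible GPUs, if any

# .module is the original, unwrapped model object.
assert dp.module is model

x = torch.randn(8, 4, device=device)
y_wrapped = dp(x)        # multi-GPU: scatter -> replicate -> parallel forward -> gather
y_direct = dp.module(x)  # plain single-device forward; DataParallel is bypassed
```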

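The device mismatch is raised because torch.distributions objects do not have a to() method and are not registered as modules. Their tensors are therefore neither moved by model.to() nor copied to each GPU by nn.DataParallel: every replica's self.prior still points at the tensors on cuda:0, while that replica's input chunk lives on its own device (cuda:1 in the traceback).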
Here is a minimal code snippet to reproduce this issue:

import torch
import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self, prior):
        super(MyModel, self).__init__()
        self.prior = prior

    def forward(self, x):
        y = self.prior.log_prob(x)
        return y

img_size = 1
device = 'cuda:0'  # this repro needs at least two GPUs (cuda:0 and cuda:1)
prior = torch.distributions.MultivariateNormal(torch.zeros(img_size).to(device), torch.eye(img_size).to(device))

model = MyModel(prior)
print(model.prior)
model.to('cuda:1')
print(model.prior)

x = torch.randn(1).to('cuda:1')
out = model(x)
print(out)

As you can see, even after calling model.to('cuda:1'), model.prior is still on cuda:0, and the forward pass raises the same error.
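If two GPUs are not at hand, the same behaviour can be seen on the CPU by converting the dtype instead of the device, since Module.to() only converts registered parameters and buffers in both cases. A minimal sketch, using a toy 3-dimensional prior as an assumption:

```python
import torch
import torch.nn as nn
from torch.distributions import MultivariateNormal

class MyModel(nn.Module):
    def __init__(self, prior):
        super().__init__()
        self.prior = prior  # plain attribute: not a parameter, buffer, or submodule

    def forward(self, x):
        return self.prior.log_prob(x)

prior = MultivariateNormal(torch.zeros(3), torch.eye(3))
model = MyModel(prior)

print(isinstance(prior, nn.Module))     # False: not registered as a submodule
print(hasattr(prior, "to"))             # False: distributions have no .to()
print(list(model.state_dict().keys()))  # []: nothing for Module.to() to convert

model.to(torch.float64)                 # converts parameters and buffers only
print(model.prior.loc.dtype)            # torch.float32: the prior was left alone
```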

This also seems to be a known issue, so feel free to add your use case to it.
I wanted to suggest the same workaround as in this post, i.e., register the loc and scale as buffers and recreate the distribution in the forward method.
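A minimal sketch of that workaround, assuming a hypothetical helper module named FlowPrior (the name and the use of scale_tril are assumptions, not taken from the linked post). The mean and the Cholesky factor are registered as buffers, so .to() moves them and nn.DataParallel copies them to every replica; the MultivariateNormal is rebuilt from those buffers on each call, so it should always live on the same device as the replica's input.

```python
import torch
import torch.nn as nn
from torch.distributions import MultivariateNormal

class FlowPrior(nn.Module):
    """Hypothetical helper: keeps the prior's tensors as buffers and
    builds the distribution on demand from wherever they currently live."""

    def __init__(self, dim):
        super().__init__()
        self.register_buffer("loc", torch.zeros(dim))
        self.register_buffer("scale_tril", torch.eye(dim))

    def dist(self):
        # Recreated on every call from the (possibly moved) buffers.
        return MultivariateNormal(self.loc, scale_tril=self.scale_tril)

    def forward(self, z):
        return self.dist().log_prob(z)

# Usage: the buffers follow the module, unlike a stored distribution object.
prior = FlowPrior(28)
z = torch.randn(4, 1, 28, 28)
print(prior(z).shape)  # torch.Size([4, 1, 28])
```

In Flow, this would mean passing img_size instead of a prebuilt prior, setting self.prior = FlowPrior(img_size), calling self.prior(z) in forward, and self.prior.dist().sample(...) in get_sample. Rebuilding the distribution every step should be cheap here, since passing scale_tril avoids the Cholesky factorization that a covariance_matrix argument would trigger.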


Thank you for your response. I was able to bypass this using the solution from this post.