I identified the root cause. The main difference was how I was calling the forward method. I simply changed from

```
z, logdet, logprob = model.forward(x)
```

to

```
z, logdet, logprob = model.module.forward(x)
```

Below is my complete code to reproduce the issue. Line 2 of the *forward* method of the *Flow* class causes the error. *I am exploring more on how calling forward with module resolves the issue but it would be great if you could help me understand and provide more clarity as it gets a little confusing for me to work with MultiVariateNormal and GPUs.* Thank you so much.

ERROR

```
File "/flows.py", line 91, in forward
logprob = self.prior.log_prob(z).view(x.size(0), -1).sum(1) #Error encountered here
File "/home/saandeepaath/.local/lib/python3.7/site-packages/torch/distributions/multivariate_normal.py", line 207, in log_prob
diff = value - self.loc
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:1 and cuda:0!
```

- flows.py (Model classes)

```
import torch
from torch import nn
from sys import exit as e
class SimpleNet(nn.Module):
def __init__(self, inp, parity):
super(SimpleNet, self).__init__()
self.net = nn.Sequential(
nn.Conv2d(inp, 32, 3, 1, 1),
nn.Tanh(),
nn.Conv2d(32, 64, 3, 1, 1),
nn.Tanh(),
nn.Conv2d(64, 32, 3, 1, 1),
nn.Tanh(),
nn.Conv2d(32, inp, 3, 1, 1),
nn.Tanh(),
)
self.inp = inp
self.parity = parity
def forward(self, x):
z = torch.zeros_like(x)
x0, x1 = x[:, :, ::2, ::2], x[:, :, 1::2, 1::2]
if self.parity % 2:
x0, x1 = x1, x0
z1 = x1
log_s = self.net(x1)
t = self.net(x1)
s = torch.exp(log_s)
z0 = (s * x0) + t
if self.parity%2:
z0, z1 = z1, z0
z[:, :, ::2, ::2] = z0
z[:, :, 1::2, 1::2] = z1
logdet = torch.sum(torch.log(s), dim = 1)
return z, logdet
def reverse(self, z):
x = torch.zeros_like(z)
z0, z1 = z[:, :, ::2, ::2], z[:, :, 1::2, 1::2]
if self.parity%2:
z0, z1 = z1, z0
x1 = z1
log_s = self.net(z1)
t = self.net(z1)
s = torch.exp(log_s)
x0 = (z0 - t)/s
if self.parity%2:
x0, x1 = x1, x0
x[:, :, ::2, ::2] = x0
x[:, :, 1::2, 1::2] = x1
return x
class Block(nn.Module):
def __init__(self, inp, n_blocks):
super(Block, self).__init__()
parity = 0
self.blocks = nn.ModuleList()
for _ in range(n_blocks):
self.blocks.append(SimpleNet(inp, parity))
parity += 1
def forward(self, x):
logdet = 0
out = x
xs = [out]
for block in self.blocks:
out, det = block(out)
logdet += det
xs.append(out)
return out, logdet
def reverse(self, z):
out = z
for block in self.blocks[::-1]:
out = block.reverse(out)
return out
class Flow(nn.Module):
def __init__(self, inp, prior, n_blocks):
super(Flow, self).__init__()
self.prior = prior
self.flow = Block(inp, n_blocks)
def forward(self, x):
z, logdet = self.flow(x)
logprob = self.prior.log_prob(z).view(x.size(0), -1).sum(1) #Error encountered here
return z, logdet, logprob
def reverse(self, z):
x = self.flow.reverse(z)
return x
def get_sample(self, n):
z = self.prior.sample(sample_shape = torch.Size([n]))
return self.reverse(z)
```

- Define multivariate distribution and instantiate the model

```
import time
import torch
from torch import optim
from torch.distributions import MultivariateNormal
from torchvision.datasets import MNIST
from torchvision import transforms
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import StepLR
from torch.nn.parallel import DataParallel
def startup(opt, use_cuda):
torch.manual_seed(1)
device = "cuda" if not opt.no_cuda and torch.cuda.is_available() else "cpu"
kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
inpt = 100
dim = 1
img_size = 28
n_block = 9
epochs = 5
lr = 0.01
wd=1e-3
old_loss = 1e6
best_loss = 0
batch_size = 128
prior = MultivariateNormal(torch.zeros(img_size).to(device), torch.eye(img_size).to(device))
#MNIST
transform = transforms.Compose([transforms.ToTensor()])
train_dataset = MNIST(root=opt.root, train=True, transform=transform, \
download=True)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, **kwargs)
model = Flow(dim, prior, n_block)
model = model.to(device)
if use_cuda and torch.cuda.device_count()>1:
model = DataParallel(model, device_ids=[0, 1, 2, 3])
optimizer = optim.Adam(model.parameters(), lr)
scheduler = StepLR(optimizer, step_size=1, gamma=0.7)
t0 = time.time()
for epoch in range(epochs):
model.train()
train_data(opt, model, device, train_loader, optimizer, epoch)
scheduler.step()
print(f"time to complete {epochs} epoch: {time.time() - t0} seconds")
```

- Training step

```
def preprocess(x):
x = x * 255
x = torch.floor(x/2**3)
x = x/32 - 0.5
return x
def train_data(opt, model, device, train_loader, optimizer, epoch):
for b, (x, _) in enumerate(train_loader):
optimizer.zero_grad()
x = x.to(device)
x = preprocess(x)
z, logdet, logprob = model.forward(x) # Causes error. Change to model.module.forward(x) to resolve
(rest of the step skipped as the above line is where my code throws error)
```