I tracked down the root cause. The difference was in how I was calling the forward method. I simply changed

z, logdet, logprob = model.forward(x)

to

z, logdet, logprob = model.module.forward(x)

Below is my complete code to reproduce the issue; the second line of Flow.forward is where the error occurs. I am still exploring why calling forward through module resolves it, so I would be grateful for any clarification, as working with MultivariateNormal across multiple GPUs is still a little confusing to me. Thank you so much.
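From what I understand so far (and this is exactly where I would appreciate confirmation): DataParallel replicates a module's parameters and buffers onto every GPU, but self.prior is a plain Python attribute, so the same distribution object, with its loc and covariance tensors on cuda:0, is shared by all replicas. Calling model.module.forward(x) bypasses the wrapper and runs the whole forward pass on a single device, which would explain why the error disappears (at the cost of losing the multi-GPU scatter). A minimal sketch of the mismatch, assuming at least two GPUs (shapes are illustrative, not from my code):

import torch
from torch.distributions import MultivariateNormal

# The prior's tensors are created on cuda:0 and never move.
prior = MultivariateNormal(torch.zeros(28, device="cuda:0"),
                           torch.eye(28, device="cuda:0"))

# Inside a DataParallel replica running on cuda:1, z lives on cuda:1 ...
z = torch.randn(4, 1, 28, 28, device="cuda:1")

# ... so this line reproduces the same RuntimeError (cuda:1 vs cuda:0).
prior.log_prob(z)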
ERROR
File "/flows.py", line 91, in forward
logprob = self.prior.log_prob(z).view(x.size(0), -1).sum(1) #Error encountered here
File "/home/saandeepaath/.local/lib/python3.7/site-packages/torch/distributions/multivariate_normal.py", line 207, in log_prob
diff = value - self.loc
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:1 and cuda:0!
- flows.py (Model classes)
import torch
from torch import nn

class SimpleNet(nn.Module):
    def __init__(self, inp, parity):
        super(SimpleNet, self).__init__()
        self.net = nn.Sequential(
            nn.Conv2d(inp, 32, 3, 1, 1),
            nn.Tanh(),
            nn.Conv2d(32, 64, 3, 1, 1),
            nn.Tanh(),
            nn.Conv2d(64, 32, 3, 1, 1),
            nn.Tanh(),
            nn.Conv2d(32, inp, 3, 1, 1),
            nn.Tanh(),
        )
        self.inp = inp
        self.parity = parity

    def forward(self, x):
        # Affine coupling layer over a checkerboard split of the image.
        z = torch.zeros_like(x)
        x0, x1 = x[:, :, ::2, ::2], x[:, :, 1::2, 1::2]
        if self.parity % 2:
            x0, x1 = x1, x0
        z1 = x1
        log_s = self.net(x1)
        t = self.net(x1)
        s = torch.exp(log_s)
        z0 = (s * x0) + t
        if self.parity % 2:
            z0, z1 = z1, z0
        z[:, :, ::2, ::2] = z0
        z[:, :, 1::2, 1::2] = z1
        logdet = torch.sum(torch.log(s), dim=1)
        return z, logdet

    def reverse(self, z):
        x = torch.zeros_like(z)
        z0, z1 = z[:, :, ::2, ::2], z[:, :, 1::2, 1::2]
        if self.parity % 2:
            z0, z1 = z1, z0
        x1 = z1
        log_s = self.net(z1)
        t = self.net(z1)
        s = torch.exp(log_s)
        x0 = (z0 - t) / s
        if self.parity % 2:
            x0, x1 = x1, x0
        x[:, :, ::2, ::2] = x0
        x[:, :, 1::2, 1::2] = x1
        return x

class Block(nn.Module):
    def __init__(self, inp, n_blocks):
        super(Block, self).__init__()
        parity = 0
        self.blocks = nn.ModuleList()
        for _ in range(n_blocks):
            self.blocks.append(SimpleNet(inp, parity))
            parity += 1

    def forward(self, x):
        logdet = 0
        out = x
        xs = [out]
        for block in self.blocks:
            out, det = block(out)
            logdet += det
            xs.append(out)
        return out, logdet

    def reverse(self, z):
        out = z
        for block in self.blocks[::-1]:
            out = block.reverse(out)
        return out

class Flow(nn.Module):
    def __init__(self, inp, prior, n_blocks):
        super(Flow, self).__init__()
        self.prior = prior  # plain attribute: not replicated by DataParallel
        self.flow = Block(inp, n_blocks)

    def forward(self, x):
        z, logdet = self.flow(x)
        logprob = self.prior.log_prob(z).view(x.size(0), -1).sum(1)  # Error encountered here
        return z, logdet, logprob

    def reverse(self, z):
        x = self.flow.reverse(z)
        return x

    def get_sample(self, n):
        z = self.prior.sample(sample_shape=torch.Size([n]))
        return self.reverse(z)
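In case it helps the discussion, here is the workaround I am considering if I want to keep the multi-GPU scatter instead of bypassing it: register the prior's parameters as buffers (which DataParallel does replicate) and build the distribution inside forward, so each replica constructs it on its own device. This is only a sketch; FlowBufferPrior and the buffer names are mine, not from the code above, and it reuses the Block class defined earlier:

from torch.distributions import MultivariateNormal

class FlowBufferPrior(nn.Module):
    def __init__(self, inp, img_size, n_blocks):
        super(FlowBufferPrior, self).__init__()
        self.flow = Block(inp, n_blocks)
        # Buffers, unlike plain attributes, are copied to each replica's GPU.
        self.register_buffer("prior_loc", torch.zeros(img_size))
        self.register_buffer("prior_cov", torch.eye(img_size))

    def forward(self, x):
        z, logdet = self.flow(x)
        # Rebuild the prior on whatever device this replica runs on.
        prior = MultivariateNormal(self.prior_loc, self.prior_cov)
        logprob = prior.log_prob(z).view(x.size(0), -1).sum(1)
        return z, logdet, logprob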
- Define multivariate distribution and instantiate the model
import time
import torch
from torch import optim
from torch.distributions import MultivariateNormal
from torchvision.datasets import MNIST
from torchvision import transforms
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import StepLR
from torch.nn.parallel import DataParallel

def startup(opt, use_cuda):
    torch.manual_seed(1)
    device = "cuda" if not opt.no_cuda and torch.cuda.is_available() else "cpu"
    kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
    inpt = 100
    dim = 1
    img_size = 28
    n_block = 9
    epochs = 5
    lr = 0.01
    wd = 1e-3
    old_loss = 1e6
    best_loss = 0
    batch_size = 128

    # The prior's tensors are created once, on the default CUDA device (cuda:0).
    prior = MultivariateNormal(torch.zeros(img_size).to(device), torch.eye(img_size).to(device))

    # MNIST
    transform = transforms.Compose([transforms.ToTensor()])
    train_dataset = MNIST(root=opt.root, train=True, transform=transform,
                          download=True)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, **kwargs)

    model = Flow(dim, prior, n_block)
    model = model.to(device)
    if use_cuda and torch.cuda.device_count() > 1:
        model = DataParallel(model, device_ids=[0, 1, 2, 3])

    optimizer = optim.Adam(model.parameters(), lr)
    scheduler = StepLR(optimizer, step_size=1, gamma=0.7)

    t0 = time.time()
    for epoch in range(epochs):
        model.train()
        train_data(opt, model, device, train_loader, optimizer, epoch)
        scheduler.step()
    print(f"time to complete {epochs} epochs: {time.time() - t0} seconds")
- Training step
def preprocess(x):
    # Quantize [0, 1] pixels to 32 levels and shift to roughly [-0.5, 0.5).
    x = x * 255
    x = torch.floor(x / 2 ** 3)
    x = x / 32 - 0.5
    return x

def train_data(opt, model, device, train_loader, optimizer, epoch):
    for b, (x, _) in enumerate(train_loader):
        optimizer.zero_grad()
        x = x.to(device)
        x = preprocess(x)
        z, logdet, logprob = model.forward(x)  # Causes error. Change to model.module.forward(x) to resolve

(rest of the step omitted, as the line above is where my code throws the error)
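If the buffer-based sketch above is right, I believe the training step could then call the wrapped model normally and keep the parallel execution:

z, logdet, logprob = model(x)  # hypothetical: DataParallel's own forward handles scatter/replicate/gather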