Issue with pytorch tensors and multiple GPUs when using DataParallel

I have a large ML code that I’ve been writing for a few months, and I’ve started trying to parallelize the data side of things to work across multiple GPUs. To start, the code works perfectly on a single GPU; the issue only appears when using multiple GPUs.

The error is as follows: RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:1 and cuda:0! (when checking argument for argument index in method wrapper_CUDA__index_select)

Snippets of relevant code can be found below:

Model file

import torch
from torch import nn
from functools import partial
import copy

from ..mlp import MLP
from ..basis import gaussian, bessel
from ..conv import GatedGCN


class Encoder(nn.Module):
    """ALIGNN/ALIGNN-d Encoder.
    The encoder must take a PyG graph object `data` and output the same `data`
    with additional fields `h_atm`, `h_bnd`, and `h_ang` that correspond to the atom, bond, and angle embedding.

    The input `data` must have three fields `x_atm`, `x_bnd`, and `x_ang` that describe the atom type
    (in onehot vectors), the bond lengths, and bond/dihedral angles (in radians).
    """

    def __init__(self, num_species, cutoff, dim=128, dihedral=False):
        super().__init__()
        self.num_species = num_species
        self.cutoff = cutoff
        self.dim = dim
        self.dihedral = dihedral

        self.embed_atm = nn.Sequential(MLP([num_species, dim, dim], act=nn.SiLU()), nn.LayerNorm(dim))
        self.embed_bnd = partial(bessel, start=0, end=cutoff, num_basis=dim)
        self.embed_ang = self.embed_ang_with_dihedral if dihedral else self.embed_ang_without_dihedral

    def embed_ang_with_dihedral(self, x_ang, mask_dih_ang):
        cos_ang = torch.cos(x_ang)
        sin_ang = torch.sin(x_ang)

        h_ang = torch.zeros([len(x_ang), self.dim], device=x_ang.device)
        h_ang[~mask_dih_ang, :self.dim // 2] = gaussian(cos_ang[~mask_dih_ang], start=-1, end=1,
                                                        num_basis=self.dim // 2)

        h_cos_ang = gaussian(cos_ang[mask_dih_ang], start=-1, end=1, num_basis=self.dim // 4)
        h_sin_ang = gaussian(sin_ang[mask_dih_ang], start=-1, end=1, num_basis=self.dim // 4)
        h_ang[mask_dih_ang, self.dim // 2:] = torch.cat([h_cos_ang, h_sin_ang], dim=-1)

        return h_ang

    def embed_ang_without_dihedral(self, x_ang, mask_dih_ang):
        cos_ang = torch.cos(x_ang)
        return gaussian(cos_ang, start=-1, end=1, num_basis=self.dim)

    def forward(self, data):
        # Embed atoms
        data.h_atm = self.embed_atm(data.x_atm)

        # Embed bonds
        data.h_bnd = self.embed_bnd(data.x_bnd)

        # Embed angles
        data.h_ang = self.embed_ang(data.x_ang, data.mask_dih_ang)

        return data


class Processor(nn.Module):
    """ALIGNN Processor.
    The processor updates atom, bond, and angle embeddings.
    """

    def __init__(self, num_convs, dim):
        super().__init__()
        self.num_convs = num_convs
        self.dim = dim

        self.atm_bnd_convs = nn.ModuleList([copy.deepcopy(GatedGCN(dim, dim)) for _ in range(num_convs)])
        self.bnd_ang_convs = nn.ModuleList([copy.deepcopy(GatedGCN(dim, dim)) for _ in range(num_convs)])

    def forward(self, data):
        edge_index_G = data.edge_index_G
        edge_index_A = data.edge_index_A

        for i in range(self.num_convs):
            data.h_bnd, data.h_ang = self.bnd_ang_convs[i](data.h_bnd, edge_index_A, data.h_ang)
            data.h_atm, data.h_bnd = self.atm_bnd_convs[i](data.h_atm, edge_index_G, data.h_bnd)

        return data


class Decoder(nn.Module):
    def __init__(self, node_dim, out_dim):
        super().__init__()
        self.node_dim = node_dim
        self.out_dim = out_dim
        self.decoder = MLP([node_dim, node_dim, out_dim], act=nn.SiLU())

    def forward(self, data):
        return self.decoder(data.h_atm)


class ALIGNN(nn.Module):
    """ALIGNN model.
    Can optionally encode dihedral angles.
    """

    def __init__(self, encoder, processor, decoder):
        super().__init__()
        self.encoder = encoder
        self.processor = processor
        self.decoder = decoder

    def forward(self, data):
        data = self.encoder(data)
        data = self.processor(data)
        return self.decoder(data)

Training file

from tqdm.notebook import trange
from datetime import datetime
import glob
import sys
import os

import torch
from torch import nn
# Assumed import: the PyG DataLoader, which supports the follow_batch argument used below
from torch_geometric.loader import DataLoader

def train(loader,model,parameters,PIN_MEMORY=False):
    model.train()
    total_loss = 0.0
    model = nn.DataParallel(model, device_ids=[0, 1]).cuda()
    #model = model.to(parameters['device'])
    optimizer = torch.optim.AdamW(model.module.processor.parameters(), lr=parameters['LEARN_RATE'])

    #model = model.to(parameters['device'])
    loss_fn = torch.nn.MSELoss()
    for i,data in enumerate(loader, 0):
        optimizer.zero_grad(set_to_none=True)
        #data = data.to(parameters['device'], non_blocking=PIN_MEMORY)
        data = data.cuda()
        #encoding = model.encoder(data)
        #proc = model.processor(encoding.module)
        #atom_contrib, bond_contrib, angle_contrib = model.decoder(proc.module)
        atom_contrib, bond_contrib, angle_contrib = model(data)

        all_sum = atom_contrib.sum() + bond_contrib.sum() + angle_contrib.sum()

        loss = loss_fn(all_sum, data.y[0][0])
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(loader)

def run_training(data,parameters,model):
    follow_batch = ['x_atm', 'x_bnd', 'x_ang'] if hasattr(data['training'][0], 'x_ang') else ['x_atm']
    loader_train = DataLoader(data['training'], batch_size=parameters['BATCH_SIZE'], shuffle=True, follow_batch=follow_batch)
    loader_valid = DataLoader(data['validation'], batch_size=parameters['BATCH_SIZE'], shuffle=False)

    L_train, L_valid = [], []
    min_loss_train = 1.0E30
    min_loss_valid = 1.0E30

    stats_file = open(os.path.join(parameters['model_dir'],'loss.data'),'w')
    stats_file.write('Training_loss     Validation loss\n')
    stats_file.close()
    for ep in range(parameters['num_epochs']):
        stats_file = open(os.path.join(parameters['model_dir'], 'loss.data'), 'a')
        print('Epoch ',ep,' of ',parameters['num_epochs'])
        sys.stdout.flush()
        loss_train = train(loader_train, model, parameters)
        L_train.append(loss_train)
        loss_valid = test_non_intepretable(loader_valid, model, parameters)
        L_valid.append(loss_valid)
        stats_file.write(str(loss_train) + '     ' + str(loss_valid) + '\n')
        if loss_train < min_loss_train:
            min_loss_train = loss_train
            if loss_valid < min_loss_valid:
                min_loss_valid = loss_valid
                if parameters['remove_old_model']:
                    model_name = glob.glob(os.path.join(parameters['model_dir'], 'model_*'))
                    if len(model_name) > 0:
                        os.remove(model_name[0])
                now = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
                print('Min train loss: ', min_loss_train, ' min valid loss: ', min_loss_valid, ' time: ', now)
                torch.save(model.state_dict(), os.path.join(parameters['model_dir'], 'model_' + str(now)))
        stats_file.close()
        if loss_train < parameters['train_tolerance'] and loss_valid < parameters['train_tolerance']:
            print('Validation and training losses satisfy set tolerance... exiting training loop...')
            break

There are many other files, but I think these are the relevant ones for this specific problem; I’m happy to include more code if needed. My data is stored as a graph and batched by a DataLoader, and looks like this (when batched two at a time):

Graph_DataBatch(atoms=[2], edge_index_G=[2, 89966], edge_index_A=[2, 1479258], x_atm=[5184, 5], x_atm_batch=[5184], x_atm_ptr=[3], x_bnd=[89966], x_bnd_batch=[89966], x_bnd_ptr=[3], x_ang=[1479258], x_ang_batch=[1479258], x_ang_ptr=[3], mask_dih_ang=[1479258], atm_amounts=[6], bnd_amounts=[6], ang_amounts=[6], y=[179932, 1])

The problematic line is in the Encoder during its forward function:

def forward(self, data):
        # Embed atoms
        data.h_atm = self.embed_atm(data.x_atm)

        # Embed bonds
        data.h_bnd = self.embed_bnd(data.x_bnd)

        # Embed angles
        data.h_ang = self.embed_ang(data.x_ang, data.mask_dih_ang)

        return data

Specifically, data.h_atm = self.embed_atm(data.x_atm). To briefly explain what I’m trying to do: I am loading a bunch of graphs into the DataLoader, which are batched and then fed into the model for training:

def train(loader,model,parameters,PIN_MEMORY=False):
    model.train()
    total_loss = 0.0
    model = nn.DataParallel(model, device_ids=[0, 1]).cuda()
    #model = model.to(parameters['device'])
    optimizer = torch.optim.AdamW(model.module.processor.parameters(), lr=parameters['LEARN_RATE'])

    #model = model.to(parameters['device'])
    loss_fn = torch.nn.MSELoss()
    for i,data in enumerate(loader, 0):
        optimizer.zero_grad(set_to_none=True)
        #data = data.to(parameters['device'], non_blocking=PIN_MEMORY)
        data = data.cuda()
        #encoding = model.encoder(data)
        #proc = model.processor(encoding.module)
        #atom_contrib, bond_contrib, angle_contrib = model.decoder(proc.module)
        atom_contrib, bond_contrib, angle_contrib = model(data)

My understanding is that I have sent my batched data to the GPU, my model parameters are on the GPU, and DataParallel should take care of splitting my data up and sending everything to each GPU automatically. For reference, here is a minimal sketch (toy shapes, plain tensors rather than my graph objects) of what I expect DataParallel to be doing:
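import torch
from torch import nn

# Toy example with a plain tensor input (not my graph data):
# nn.DataParallel replicates the module onto cuda:0 and cuda:1 and
# splits the input batch along dim 0 between the two replicas.
model = nn.DataParallel(nn.Linear(5, 1), device_ids=[0, 1]).cuda()

x = torch.randn(8, 5).cuda()   # whole batch starts on cuda:0
out = model(x)                 # each replica sees a (4, 5) chunk on its own GPU
print(out.device)              # outputs are gathered back onto cuda:0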

My question can be broken into a few parts: (1) Is this understanding correct? (2) Does my code actually seem like it’s doing this? And (3) does this error have anything to do with that, and if not, what is the error trying to tell me? I don’t expect anyone to fix my code for me, but I would like to understand why this error is happening, because I think I’m misunderstanding the underlying logic of how DataParallel takes my data and sends it to each GPU. I’m happy to provide any details you might need to better understand this problem.

I have tried to better understand the line that breaks (data.h_atm = self.embed_atm(data.x_atm)) by printing out where data.x_atm actually is inside the forward function, which should be after DataParallel has partitioned the data, and I get this for all tensors:

tensor([[1., 0., 0., 0., 0.],
[1., 0., 0., 0., 0.],
[1., 0., 0., 0., 0.],
...,
[0., 0., 1., 0., 0.],
[0., 0., 1., 0., 0.],
[0., 0., 1., 0., 0.]], device='cuda:0')

I think this is telling me that all of my data is on GPU 0, despite the model being on both GPU 0 and 1 (when running on 2 GPUs), which I’ve confirmed by using nvidia-smi and observing both GPUs at about half memory consumption. I have also tried various combinations of calling X.to('cuda') or X.cuda(), where X is my graph data, but nothing seems to make any difference to the tensor printout.

These issues are often caused by explicit data movement added to the forward methods of your modules. I don’t see anything obviously wrong, but check if you are moving any tensor to e.g. the default device.

If you cannot narrow down which input is on the wrong device, add debug print statements showing the device of your input data as well as the parameters.
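For example, something like this at the top of your Encoder.forward (field names taken from your snippet) would show where each input and this replica’s parameters live:

def forward(self, data):
    # Debug: print the device of each input field and of this replica's parameters
    print("x_atm:", data.x_atm.device,
          "x_bnd:", data.x_bnd.device,
          "x_ang:", data.x_ang.device,
          "params:", next(self.parameters()).device)
    ...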

Yes, this is correct. DataParallel will clone the model onto each device in the forward pass and will split the inputs sending the chunks to the corresponding device. This adds a lot of overhead, which is why we generally recommend using DistributedDataParallel.

Yes, this could be the case. I don’t know what data is, but if it’s not a plain tensor, I guess its internal tensors won’t be moved to the right device. You could try to move them explicitly using the device of any parameter, e.g.:

data.h_atm = data.h_atm.to(next(self.parameters()).device)
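If data is a PyG Data/Batch object, it should also expose a .to() that moves all of its tensor attributes, so you could move the whole object at the top of forward. A rough sketch, assuming the PyG API:

def forward(self, data):
    # Move every tensor attribute of the (assumed) PyG Batch onto this replica's device
    data = data.to(next(self.parameters()).device)

    data.h_atm = self.embed_atm(data.x_atm)
    data.h_bnd = self.embed_bnd(data.x_bnd)
    data.h_ang = self.embed_ang(data.x_ang, data.mask_dih_ang)
    return data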

Thanks for your input. You were right that the tensors in data were not being moved to the correct device. The solution was to manually call data = data.cuda() inside every forward call. What I don’t understand is that, inside the training loop, when iterating through the batched data from the DataLoader, I call data = data.cuda(), but when that data is passed to a model’s forward function it isn’t sent to the correct GPU. Is the issue that the data = data.cuda() call made inside the DataLoader loop isn’t in communication with DataParallel at that point, and is therefore just sent to the default GPU device?

Calling data.cuda() inside the forward method of a module wrapped in nn.DataParallel will move it to the currently set device, which is also where that replica’s parameters were moved to.
nn.DataParallel uses parallel_apply internally, which sets the current device for each replica before running its forward.
If you do not specify a device ID, the currently set device is used, which is why this fixes your error.

Calling data.cuda() outside will also use the currently set device, which is most likely the default cuda:0 device, and will thus fail.
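To make the “currently set device” behaviour concrete, here is a small standalone example:

import torch

# .cuda() without an explicit index uses the *currently set* device.
x = torch.randn(4).cuda()
print(x.device)             # cuda:0, the default device

torch.cuda.set_device(1)    # roughly what parallel_apply does for each replica
y = torch.randn(4).cuda()
print(y.device)             # cuda:1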

Got it. So, I just got done converting my code to DistributedDataParallel and it is much faster and everything seems to be working. What I don’t understand is why the data = data.cuda() calls aren’t needed in the forward functions when using DDP but are needed when using DP?

DDP will use a single process per device and assumes each process uses its corresponding GPU only. Usually you would call torch.cuda.set_device at the beginning of your script to make sure all .cuda() or to("cuda") calls use this specified device.
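A minimal per-process sketch (names here are placeholders, not your code; it assumes MASTER_ADDR/MASTER_PORT are already set, and uses the nccl backend typical for GPUs):

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

def worker(rank, world_size, model, loader):
    # One process per GPU: pin this process to its GPU first, so every
    # .cuda() / .to("cuda") call below lands on the right device.
    torch.cuda.set_device(rank)
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

    model = DDP(model.cuda(), device_ids=[rank])
    for data in loader:
        data = data.cuda()   # goes to this process's GPU; no per-forward moves needed
        out = model(data)

    dist.destroy_process_group()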

Great to hear DDP is giving you a speedup!

So, I wanted to follow up on this and get some additional feedback. I have the code working on our local GPU machine and it runs smoothly. I’ve since moved the code to our HPC cluster and I’m running into a problem regarding Exclusive vs Default mode on the GPU. Essentially, the error is:

RuntimeError: CUDA error: CUDA-capable device(s) is/are busy or unavailable

After speaking with our HPC support staff, the issue is that in Exclusive mode each GPU cannot have multiple processes spawned on it. However, I am running the line below in my main script:

mp.spawn(run_training, args=(partitioned_data, ml.parameters, ml.model, ml),
         nprocs=ml_parameters['world_size'], join=True)

And setting world_size equal to the number of GPUs, my understanding is that each GPU should only have 1 process running on it. Is that correct? The rest of my code that uses DDP is below:

def ddp_setup(rank: int,world_size):
    """
    Args:
    rank: Unique identifier of each process
    world_size: Total number of processes
    """
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"
    torch.cuda.set_device(rank)
    init_process_group(backend="gloo", rank=rank, world_size=world_size)

and

loader_train = DataLoader(data['training'], batch_size=parameters['BATCH_SIZE'], shuffle=False, follow_batch=follow_batch,sampler=DistributedSampler(data['training']))

I assume you are referring to EXCLUSIVE_PROCESS set via nvidia-smi. If so, then only a single process is allowed to initialize a CUDA context on this device and multiple threads may submit work to this context.
If your approach tries to create multiple CUDA contexts on the same device, your script is wrong and you might not be setting the device correctly as mentioned above:

Usually you would call torch.cuda.set_device at the beginning of your script […]
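A quick sanity check is to print, from inside each spawned process, which device it actually ends up on. A hypothetical wrapper around your entry point (assumes torch and os are imported and world_size is available in that scope):

def run_training(rank, *args):
    ddp_setup(rank, world_size)
    print(f"rank {rank}: pid={os.getpid()}, "
          f"current device={torch.cuda.current_device()}, "
          f"device count={torch.cuda.device_count()}")
    # ... rest of the training entry point ...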

Isn’t that what I’m doing here though:

def ddp_setup(rank: int,world_size):
    """
    Args:
    rank: Unique identifier of each process
    world_size: Total number of processes
    """
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"
    torch.cuda.set_device(rank)
    init_process_group(backend="gloo", rank=rank, world_size=world_size)

I’m also confused about this:
If so, then only a single process is allowed to initialize a CUDA context on this device and multiple threads may submit work to this context.

If I only have 2 GPUs and request 2 processes, each of which should be assigned rank 0 or rank 1, then how can either GPU report more than 1 process running on it? Is there a way I can debug this via simple print commands? I don’t think I understand what is happening at a fundamental level. If it helps, here’s the full error:

Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/share/pkg.7/python3/3.9.4/install/lib/python3.9/multiprocessing/spawn.py", line 116, in spawn_main
    exitcode = _main(fd, parent_sentinel)
  File "/share/pkg.7/python3/3.9.4/install/lib/python3.9/multiprocessing/spawn.py", line 126, in _main
    self = reduction.pickle.load(from_parent)
  File "/project/cmdlab/software/python3.9.4/lib/python3.9/site-packages/torch/multiprocessing/reductions.py", line 147, in rebuild_cuda_tensor
    storage = storage_cls._new_shared_cuda(
  File "/project/cmdlab/software/python3.9.4/lib/python3.9/site-packages/torch/storage.py", line 1085, in _new_shared_cuda
    return torch.UntypedStorage._new_shared_cuda(*args, **kwargs)
RuntimeError: CUDA error: CUDA-capable device(s) is/are busy or unavailable
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.