What to do with to(device) when using DataParallel?

It’s not clear to me what to do with .to(device) when using DataParallel.

Imagine something like

device = torch.device('cuda:2')
model = nn.DataParallel(model, device_ids=[2,3])

In my model I now have:

def forward(self, input):
    input = input.to(device)

Outside of my model I have

output = output.to(device)
torchmetrics.Accuracy(num_classes=1).to(device)

It raises following questions:

  • Is to(device) now unnecessary? Should I remove it?
  • If not, which ID should I give to the device?
  • Is the to(device) handled the same for each case (inside the forward function, outside of it)?

For the input: You should remove input = input.to(device) from inside your forward function. Once you wrap your model with DataParallel, the wrapper will take care of splitting the inputs and placing them on the right device for you. forward will be called once for each device, so you shouldn’t move the input to a specific device.
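A minimal sketch of that setup (the TinyNet module and the shapes are made up for illustration, and it assumes GPUs 2 and 3 are available):

import torch
import torch.nn as nn

class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(8, 4)

    def forward(self, x):
        # no x.to(device) here – DataParallel has already placed x
        # on whichever GPU this replica runs on
        return self.fc(x)

device = torch.device('cuda:2')
model = TinyNet().to(device)                       # parameters live on device_ids[0]
model = nn.DataParallel(model, device_ids=[2, 3])

x = torch.randn(16, 8)       # can stay on the CPU; the wrapper scatters it
out = model(x)               # each replica runs forward on a chunk of the batch
print(out.device)            # gathered back on cuda:2 by default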

As for the output: the documentation here DataParallel — PyTorch 2.1 documentation mentions that by default output_device is set to device_ids[0] (with device_ids=[2, 3], that would be your cuda:2 device). But you can also pass the desired output device to the constructor of DataParallel.
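A small variation of the sketch above, gathering the combined output on a different GPU than device_ids[0]:

# collect the result on cuda:3 instead of the default device_ids[0] (cuda:2)
model = nn.DataParallel(TinyNet().to('cuda:2'), device_ids=[2, 3], output_device=3)
out = model(torch.randn(16, 8))
print(out.device)   # cuda:3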

Let me know if you have trouble with this.

Hmm, if I’m not supposed to have .to(device), then I have trouble understanding how PyTorch will know which tensors need to move to the GPU, especially for my use case: the input to my model is actually a list, and inside the forward function I repeatedly call a subnetwork (once per list entry, calling .to(device) in each iteration) before I use its output in my top-level network part (for which I also call .to(device)).

It’s a bit magical, I agree, but yes, that’s what it does. It goes into your input recursively (if it is a list, it goes into each element of the list, and so on), here: pytorch/scatter_gather.py at 557fbf4261d6517552b48c47be9aa9d289fa28d3 · pytorch/pytorch · GitHub

Then, for each Tensor it finds, it ultimately calls pytorch/comm.py at 557fbf4261d6517552b48c47be9aa9d289fa28d3 · pytorch/pytorch · GitHub, which scatters the tensor across the GPUs.
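If you want to see that step in isolation, the scatter helper is importable from torch.nn.parallel; a quick sketch (assuming at least two visible GPUs):

import torch
from torch.nn.parallel import scatter

x = torch.randn(8, 5)               # a plain CPU tensor
chunks = scatter(x, [0, 1], dim=0)  # the same code path DataParallel uses internally

for c in chunks:
    print(c.shape, c.device)        # torch.Size([4, 5]) on cuda:0 and cuda:1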

Let me know if, after trying it out, you have issues.

I tried to use it as a drop-in replacement, so the same setup as in the first post:

device = torch.device('cuda:4')
model = nn.DataParallel(model, device_ids=[4,5])

and as you suggested I removed the to(device) calls from inside the forward function – but I ran into this error:

RuntimeError: module must have its parameters and buffers on device cuda:4 (device_ids[0]) but found one of them on device: cpu

So, after some research, it seems you have to call to(device) on the model and the input before passing them to forward, basically moving everything to the primary GPU, from where it’ll be scattered to the other ones? That’s how I understand it right now. After doing that, and moving the input tensors to the same device, it stops returning that error message:

model = nn.DataParallel(model, device_ids=[4,5]).to(device)
model([[x.to(device) for x in y] for y in inputs], lengths.to(device), g=False)

… and allows me to reach the point in my forward function (which previously had to(device) and now doesn’t) where I dynamically create a padding mask:

mask = torch.arange(max_len).expand(len(lengths), max_len) < lengths.unsqueeze(1)

This returns the following error:

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:4 and cpu!

So I’m not sure if I’m now totally on the wrong track, but the last error at least suggests that I need some kind of to(device) inside forward – of course with the corresponding device and not a fixed one.

If you need to create new tensors in the forward method, push them to the corresponding device by using the .device attribute of any parameter or the input:

mask = ...
mask = mask.to(x.device) # x is the input to forward
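Applied to the padding mask from above, that would look roughly like this (the dummy lengths tensor and device are just for illustration):

import torch

# stand-in values – in the real forward, lengths is the already-scattered input
lengths = torch.tensor([3, 5, 2], device='cuda:4')
max_len = int(lengths.max())

# create the arange directly on lengths.device, so each replica stays on its own GPU
mask = torch.arange(max_len, device=lengths.device).expand(len(lengths), max_len) < lengths.unsqueeze(1)
print(mask.device)   # cuda:4 (or whichever GPU this replica runs on)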

Moving the newly created tensor to the GPU seems to work. However, there’s now a dimension mismatch.

It seems it doesn’t magically split my list of lists into two batches of lists of lists, so it just passes the whole object to both GPUs. On the other hand, it does seem to split the lengths tensor, causing a dimension mismatch at some point:

AssertionError: expecting key_padding_mask shape of (28, 49), but got torch.Size([14, 49])

Actually, looking more closely at the scatter_map function, it might be that in the case of a list (or list of lists) it doesn’t split the list in half*, but rather splits each individual tensor, meaning that both GPUs get a list of the same size as the original one, but with fractured tensors? Is that right? And if that’s the case, is there any way to change that behaviour?

*Assuming two GPUs.
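A quick way to check that hypothesis in isolation (sketch, again assuming two visible GPUs) is to call scatter on such a list directly and look at what each device receives:

import torch
from torch.nn.parallel import scatter

inputs = ([torch.randn(28, 49), torch.randn(28, 49)],)   # a single argument: a list of tensors
per_gpu = scatter(inputs, [0, 1], dim=0)

for device_idx, (lst,) in enumerate(per_gpu):
    print(device_idx, len(lst), [tuple(t.shape) for t in lst], lst[0].device)
# both devices receive a list of length 2 (same as the original),
# but every tensor in it is halved along dim 0: (14, 49)

If that is what happens here as well, it would be consistent with the 28 vs. 14 mismatch in the key_padding_mask error above.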