DataParallel chunking for a list of 3D tensors?

Hi everyone,

I have a point cloud learning problem where the data for each sample (point cloud) is encoded as an a x b x c tensor, where b and c are fixed across samples but a is not. Because of this I can't use the standard collate function, which stacks the data into a 4-dimensional tensor. Instead I put the samples in a simple Python list. When I train as usual with DataParallel, all data ends up on GPU 0, so I can't utilize all the GPUs on my system.

Is there a way to make DataParallel chunk the data by list index rather than along the batch dimension?
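For reference, the list-based batching described above can be written as a custom collate_fn. This is a minimal sketch; `list_collate` and the toy `clouds` data are hypothetical placeholders, not code from the original post:

```python
import torch
from torch.utils.data import DataLoader

def list_collate(batch):
    # batch is a list of (a_i x b x c) tensors; a_i varies between samples,
    # so they cannot be stacked. Return them unchanged as a plain Python list.
    return list(batch)

# Hypothetical toy data: three point clouds with different numbers of points (a)
# but the same trailing b x c = 2 x 3 shape.
clouds = [torch.randn(a, 2, 3) for a in (5, 7, 4)]
loader = DataLoader(clouds, batch_size=3, collate_fn=list_collate)

batch = next(iter(loader))
print([tuple(t.shape) for t in batch])  # [(5, 2, 3), (7, 2, 3), (4, 2, 3)]
```

A batch like this is what DataParallel's default scatter does not split: tensors get chunked, but a plain list is only duplicated to every replica.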




This solution is not very clean, but you can override the replicate and scatter methods on DataParallel to make it work. Hope this helps!


It does – thanks Simon.

Sorry, it should be scatter and gather.

Got it. I'll report back once I've given this a try!

At the end of the day, we still need from ._functions import Scatter. If Scatter does not support scattering lists, then we cannot scatter a list of tensors with non-uniform sizes. Do we really have to pad them all to the same size? That costs extra memory and communication bandwidth. The ideal approach may be to scatter the list to the separate GPUs first and then pad each part, or simply to use a for-loop inside the module.
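To make the padding cost concrete, here is a CPU-only sketch comparing "pad the whole batch, then split" with "split the list, then pad each chunk". The lengths and `num_gpus` are made-up assumptions; this only counts tensor elements and does not measure actual GPU memory or transfer time:

```python
import torch
from torch.nn.utils.rnn import pad_sequence

# Hypothetical batch: four sequences with very different lengths, feature dim 4.
seqs = [torch.ones(n, 4) for n in (2, 3, 10, 9)]
num_gpus = 2  # assumed device count, for illustration only

# Option A: pad the whole batch to the global maximum length, then scatter.
global_padded = pad_sequence(seqs, batch_first=True)
print(global_padded.numel())  # 4 samples * 10 steps * 4 features = 160

# Option B: split the list first, then pad each chunk to its own maximum length.
size = -(-len(seqs) // num_gpus)  # ceiling division
chunks = [seqs[i:i + size] for i in range(0, len(seqs), size)]
per_chunk = [pad_sequence(c, batch_first=True) for c in chunks]
print(sum(p.numel() for p in per_chunk))  # 2*3*4 + 2*10*4 = 104
```

How much option B saves depends on how the lengths happen to fall into the chunks; with a random ordering the gain is usually smaller than in this deliberately skewed example.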

I ran into this situation when training a custom RNN on multiple GPUs.

Is there a way to Scatter a list of objects? I think technically we should be able to do it.

So far I have ended up with the following approach.
First, I modified torch.nn.utils.rnn.pad_sequence so that it also returns the length of each sample:

```python
import torch

def pad_sequence(sequences, batch_first=False, padding_value=0):
    # Adapted from torch.nn.utils.rnn.pad_sequence; also returns each sample's length.
    trailing_dims = sequences[0].size()[1:]
    max_len = max(s.size(0) for s in sequences)
    if batch_first:
        out_dims = (len(sequences), max_len) + trailing_dims
    else:
        out_dims = (max_len, len(sequences)) + trailing_dims
    out_tensor = sequences[0].new_full(out_dims, padding_value)
    lengths = sequences[0].new_zeros(len(sequences), dtype=torch.long)
    for i, tensor in enumerate(sequences):
        length = tensor.size(0)
        lengths[i] = length
        if batch_first:
            out_tensor[i, :length, ...] = tensor
        else:
            out_tensor[:length, i, ...] = tensor

    return out_tensor, lengths
```
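If modifying the library function is undesirable, the same pair of outputs can be obtained from the stock pad_sequence, which already accepts batch_first and padding_value, by computing the lengths separately. A minimal sketch with made-up input shapes:

```python
import torch
from torch.nn.utils.rnn import pad_sequence

seqs = [torch.randn(2, 4), torch.randn(5, 4), torch.randn(3, 4)]

# Stock function for the padding; lengths computed separately.
padded = pad_sequence(seqs, batch_first=True, padding_value=0.0)
lengths = torch.tensor([s.size(0) for s in seqs], dtype=torch.long)

print(padded.shape)  # torch.Size([3, 5, 4])
print(lengths)       # tensor([2, 5, 3])
```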

Then I pass these two tensors into the module (you can also wrap out_tensor and lengths in a dictionary). Inside the module, I loop over the batch, and for each sample (or mini-batch of samples with the same sequence length) I loop over the time steps. Finally, I concatenate the outputs of all samples to compute the loss.
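The looping scheme above might look roughly like this. It is a sketch under assumptions: `LoopedRNN` is a hypothetical module built on nn.RNNCell, and "the output of each sample" is taken to be its final hidden state:

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence

class LoopedRNN(nn.Module):
    """Hypothetical module: runs an RNNCell over each sample's valid steps only."""

    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.hidden_size = hidden_size
        self.cell = nn.RNNCell(input_size, hidden_size)

    def forward(self, padded, lengths):
        # padded: (batch, max_len, input_size), i.e. the batch_first layout
        # lengths: (batch,) number of valid time steps per sample
        outputs = []
        for i in range(padded.size(0)):                  # loop over the batch
            h = padded.new_zeros(1, self.hidden_size)
            for t in range(int(lengths[i])):              # loop over valid steps only
                h = self.cell(padded[i, t].unsqueeze(0), h)
            outputs.append(h)
        # Concatenate per-sample results into (batch, hidden_size) for the loss.
        return torch.cat(outputs, dim=0)

torch.manual_seed(0)
seqs = [torch.randn(n, 4) for n in (2, 5, 3)]
lengths = torch.tensor([s.size(0) for s in seqs])
model = LoopedRNN(input_size=4, hidden_size=8)

out_zero_pad = model(pad_sequence(seqs, batch_first=True, padding_value=0.0), lengths)
out_big_pad = model(pad_sequence(seqs, batch_first=True, padding_value=99.0), lengths)
print(out_zero_pad.shape)                         # torch.Size([3, 8])
print(torch.allclose(out_zero_pad, out_big_pad))  # True: padded steps are never read
```

Because the inner loop stops at each sample's length, the padding value has no effect on the result, which is the point of passing lengths alongside the padded tensor.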

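Be aware: if you want to do this on multiple GPUs with nn.DataParallel, you should call pad_sequence() with batch_first=True. DataParallel splits its inputs along dim 0 by default, so with batch_first=False it would split the time dimension instead of the batch.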
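The effect of batch_first can be checked on CPU. DataParallel's default scatter cuts each tensor into pieces along dim 0, much like torch.chunk, so chunk is used below as a stand-in; `num_gpus` is an assumed device count:

```python
import torch
from torch.nn.utils.rnn import pad_sequence

seqs = [torch.randn(n, 3) for n in (2, 6, 4, 5)]
lengths = torch.tensor([s.size(0) for s in seqs])
num_gpus = 2  # assumed device count

# batch_first=True gives (batch, max_len, feat): splitting dim 0 splits samples,
# and each piece lines up with the matching slice of `lengths`.
bf = pad_sequence(seqs, batch_first=True)
print([tuple(c.shape) for c in bf.chunk(num_gpus, dim=0)])   # [(2, 6, 3), (2, 6, 3)]
print([c.tolist() for c in lengths.chunk(num_gpus, dim=0)])  # [[2, 6], [4, 5]]

# batch_first=False gives (max_len, batch, feat): splitting dim 0 splits time steps,
# so each replica would see all samples but only part of each sequence, while
# still receiving only half of `lengths`.
tf = pad_sequence(seqs, batch_first=False)
print([tuple(c.shape) for c in tf.chunk(num_gpus, dim=0)])   # [(3, 4, 3), (3, 4, 3)]
```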

Hello, could you tell me exactly how you solved the problem? I used the default scatter and gather functions provided by PyTorch, but it didn't work.

Still dealing with this issue today. I need to apply DataParallel to a list of strings as input (since there is no StringTensor), and the scatter function only duplicates the list instead of chunking it. Is there any better solution than subclassing DataParallel?

Here is my code, in case anyone has use for it:

```python
import math

import torch
from torch.nn.parallel import DataParallel
from torch.nn.parallel._functions import Scatter

# This code was copied from torch.nn.parallel and adapted for DataParallel to chunk lists instead of duplicating them
# (this is really all this code is here for)


def scatter(inputs, target_gpus, dim=0):
    r"""Slices tensors into approximately equal chunks and
    distributes them across given GPUs. Chunks lists the same way
    instead of duplicating them. Duplicates references to other objects.
    """
    def scatter_map(obj):
        if isinstance(obj, torch.Tensor):
            return Scatter.apply(target_gpus, None, dim, obj)
        if isinstance(obj, tuple) and len(obj) > 0:
            return list(zip(*map(scatter_map, obj)))
        if isinstance(obj, list) and len(obj) > 0:
            # Cut at the same boundaries as torch.chunk (ceil-sized pieces, no empty
            # trailing chunks) so list items stay aligned with tensor rows and no
            # items are dropped when len(obj) is not divisible by the GPU count.
            # Note: list elements are NOT moved to the target device here.
            size = math.ceil(len(obj) / len(target_gpus))
            return [obj[i:i + size] for i in range(0, len(obj), size)]
        if isinstance(obj, dict) and len(obj) > 0:
            return list(map(type(obj), zip(*map(scatter_map, obj.items()))))
        return [obj for _ in target_gpus]

    # After scatter_map is called, a scatter_map cell will exist. This cell
    # has a reference to the actual function scatter_map, which has references
    # to a closure that has a reference to the scatter_map cell (because the
    # fn is recursive). To avoid this reference cycle, we set the function to
    # None, clearing the cell.
    try:
        return scatter_map(inputs)
    finally:
        scatter_map = None


def scatter_kwargs(inputs, kwargs, target_gpus, dim=0):
    r"""Scatter with support for kwargs dictionary"""
    inputs = scatter(inputs, target_gpus, dim) if inputs else []
    kwargs = scatter(kwargs, target_gpus, dim) if kwargs else []
    if len(inputs) < len(kwargs):
        inputs.extend([() for _ in range(len(kwargs) - len(inputs))])
    elif len(kwargs) < len(inputs):
        kwargs.extend([{} for _ in range(len(inputs) - len(kwargs))])
    inputs = tuple(inputs)
    kwargs = tuple(kwargs)
    return inputs, kwargs


class DataParallelV2(DataParallel):
    def scatter(self, inputs, kwargs, device_ids):
        # Route DataParallel's scatter step through the list-aware version above.
        return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
```
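For list chunks to stay aligned with tensor rows, the list has to be cut at the same boundaries torch.chunk uses (which, as far as I can tell, is how the default Scatter splits tensors when no chunk sizes are given). That can be checked on CPU without any GPUs. `chunk_list` is a hypothetical helper written for this check, not a PyTorch function:

```python
import math
import torch

def chunk_list(items, num_chunks):
    """Hypothetical helper: split a list at the same boundaries torch.chunk uses."""
    if not items:
        return []
    # torch.chunk puts ceil(n / num_chunks) items in each chunk and returns
    # fewer than num_chunks chunks rather than producing empty ones.
    size = math.ceil(len(items) / num_chunks)
    return [items[i:i + size] for i in range(0, len(items), size)]

print(chunk_list(list(range(10)), 4))               # [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
print([len(c) for c in torch.arange(10).chunk(4)])  # [3, 3, 3, 1]
print(chunk_list(["a", "b"], 4))                    # [['a'], ['b']]
```

Splitting with len(items) // num_chunks instead would silently drop the last item here (10 // 4 = 2 per chunk covers only 8 items), and produce empty chunks when there are fewer items than GPUs.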

Did you modify the gather function as well? I am having the same issue. After I tried your scatter function, training