Is there a way to handle batch_index-dependent data in PyTorch DataParallel?

My PyTorch model takes an input that contains a dict. That is:

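import torch.nn as nn
from torch import Tensor

class MyModel(nn.Module):

    def __init__(self):
        super().__init__()

    def forward(self, input1: Tensor, input2: dict):
        ...  # body omitted in the original post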

I know that in DataParallel or DistributedDataParallel, a function called scatter_kwargs splits inputs of type Tensor and replicates other types of data.

However, the keys of input2 are tied to the batch index, e.g. input2[0] holds the data for the first sample in the batch. After scatter_kwargs splits input1, the batch indices on each replica no longer correspond to the keys in input2.
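To make the mismatch concrete, here is a small pure-Python sketch. It is only an illustration under assumptions: chunk_rows is a hypothetical stand-in for splitting the batch dimension into contiguous chunks, the sample values are made up, and the dict is handed to every replica unsplit, as described above.

```python
def chunk_rows(rows, num_replicas):
    """Split rows into contiguous chunks of ceil(len(rows) / num_replicas) items."""
    size = max(1, -(-len(rows) // num_replicas))  # ceiling division
    return [rows[i:i + size] for i in range(0, len(rows), size)]

# Hypothetical batch of four samples and a dict keyed by global batch index.
input1 = ["x0", "x1", "x2", "x3"]
input2 = {0: "meta0", 1: "meta1", 2: "meta2", 3: "meta3"}

chunks = chunk_rows(input1, num_replicas=2)

# Each replica indexes its chunk from 0, but still sees the full dict.
pairs = []
for replica_id, chunk in enumerate(chunks):
    for local_idx, row in enumerate(chunk):
        pairs.append((replica_id, row, input2[local_idx]))

for pair in pairs:
    print(pair)
```

On the second replica, local index 0 is the sample "x2", but input2[0] still returns the entry that belongs to "x0".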

So is there an easy way to split the data in the dictionary along with the Tensor input?

Hey @Randool, as of PyTorch v1.9, DistributedDataParallel only supports Single-Process Single-Device mode, where each process works exclusively on a single model replica. In this mode, DistributedDataParallel does not divide the forward inputs; it passes them directly to the wrapped local model. Will this be sufficient for your use case?
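A rough sketch of what that could look like on the data side, assuming each process builds its own batch: if the dict is keyed by position within the local batch, it stays aligned because nothing is split afterwards. The names shard and collate_with_local_keys are hypothetical, and shard is only a simplified stand-in for a distributed sampler. In real code the features would be stacked into a Tensor and a function like this could be passed as collate_fn to a DataLoader.

```python
def shard(dataset, rank, world_size):
    """Hypothetical stand-in for a distributed sampler: every world_size-th sample."""
    return dataset[rank::world_size]

def collate_with_local_keys(samples):
    """Build one process's batch; the dict is keyed by position in this local batch."""
    features = [feature for feature, _ in samples]
    extras = {i: extra for i, (_, extra) in enumerate(samples)}
    return features, extras

# Made-up dataset of (feature, per-sample extra data) pairs.
dataset = [("x0", "meta0"), ("x1", "meta1"), ("x2", "meta2"), ("x3", "meta3")]

world_size = 2
batches = [
    collate_with_local_keys(shard(dataset, rank, world_size))
    for rank in range(world_size)
]

for rank, (features, extras) in enumerate(batches):
    # Local index i refers to the same sample in both features and extras.
    print(rank, [(features[i], extras[i]) for i in range(len(features))])
```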

@mrshenli Thanks! Looks like a good method, I will try it.