My PyTorch model takes an input that contains a dict. That is:
    class MyModel(nn.Module):
        def __init__(self):
            ...

        def forward(self, input1: Tensor, input2: dict):
            ...
I know that in DistributedDataParallel, a function called scatter_kwargs splits inputs of type Tensor and replicates other types of data.
However, the keys of input2 are tied to the batch index: the entry under key b holds the data for sample b along the first (batch) dimension. As a result, once the inputs have been split across devices, the batch indices each replica sees no longer correspond to the keys in input2.
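To make the mismatch concrete, here is a small simulation that needs no GPUs. It assumes, as described above, that input1 is split along dim 0 into ceil-sized chunks (like torch.chunk) while input2 is copied unchanged to every device; plain Python lists stand in for tensors, and simulate_lookups is a hypothetical helper, not a PyTorch function.

```python
# Simulates the split without GPUs, assuming input1 is split along dim 0
# (ceil-sized chunks, like torch.chunk) while input2 is copied unchanged
# to every device. Plain lists stand in for tensors.

def simulate_lookups(batch, input2, num_devices):
    """Return (row, looked-up value) pairs as each replica's forward() sees them.

    Hypothetical helper for illustration only.
    """
    chunk = -(-len(batch) // num_devices)  # ceil division
    pairs = []
    for d in range(num_devices):
        rows = batch[d * chunk:(d + 1) * chunk]  # this device's slice of input1
        replicated = input2                      # dict copied as-is, keys untouched
        for local_i, row in enumerate(rows):
            # forward() only knows the local index, so it reads input2[local_i]
            pairs.append((row, replicated[local_i]))
    return pairs


batch = ["x0", "x1", "x2", "x3"]               # global batch of 4 samples
input2 = {0: "d0", 1: "d1", 2: "d2", 3: "d3"}  # key b -> data for sample b
print(simulate_lookups(batch, input2, num_devices=2))
# [('x0', 'd0'), ('x1', 'd1'), ('x2', 'd0'), ('x3', 'd1')]
```

Samples x2 and x3 end up paired with d0 and d1 because the second device sees them as local rows 0 and 1.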
So is there an easy way to split the data in the dictionary together with the tensor input1, so that the keys still line up with the batch indices on each device?
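One possible approach, sketched under stated assumptions rather than taken from any PyTorch API: split input1 yourself and re-key input2 per chunk, shifting each key by the chunk's start offset so that key 0 refers to the first row of that chunk. The helper split_batch_with_dict is hypothetical; it assumes input2 has integer keys 0..B-1 and that chunking uses the same ceil-based sizes as torch.chunk.

```python
def split_batch_with_dict(input1, input2, num_chunks):
    """Split input1 along dim 0 and re-key input2 so its keys are chunk-local.

    Assumptions (hypothetical helper, not a PyTorch API):
      * input2 maps integer batch indices 0..B-1 to per-sample data;
      * input1 supports len() and slicing, e.g. a torch.Tensor or a list;
      * chunks use ceil-based sizes, like torch.chunk along dim 0.
    Returns a list of (input1_chunk, input2_chunk) pairs, one per chunk.
    """
    if num_chunks < 1:
        raise ValueError("num_chunks must be >= 1")
    batch_size = len(input1)
    if batch_size == 0:
        return []
    chunk = -(-batch_size // num_chunks)  # ceil division
    pieces = []
    for start in range(0, batch_size, chunk):
        stop = min(start + chunk, batch_size)
        # Keep only this chunk's keys and shift them so key 0 is its first row.
        local_dict = {k - start: v for k, v in input2.items() if start <= k < stop}
        pieces.append((input1[start:stop], local_dict))
    return pieces


pieces = split_batch_with_dict(
    ["x0", "x1", "x2", "x3", "x4"], {i: f"d{i}" for i in range(5)}, num_chunks=2
)
print(pieces)
# [(['x0', 'x1', 'x2'], {0: 'd0', 1: 'd1', 2: 'd2'}), (['x3', 'x4'], {0: 'd3', 1: 'd4'})]
```

Each pair could then be moved to its own device. With nn.DataParallel, one place to hook this in is an overridden scatter method, but that wiring is not shown here and its exact contract should be checked against the installed PyTorch version. An alternative that avoids re-keying altogether is to store input2's per-sample data in a tensor whose first dimension is the batch, so it should be split along dim 0 exactly like input1.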