My PyTorch model's `forward` takes a `dict` as one of its inputs:

```python
from torch import nn, Tensor

class MyModel(nn.Module):
    def __init__(self):
        ...

    def forward(self, input1: Tensor, input2: dict):
        ...
```
I know that in DistributedDataParallel, a function called `scatter_kwargs` splits inputs of type `Tensor` along the batch dimension and simply replicates data of any other type.
However, the keys of `input2` are batch indices, e.g. `input2[b]` holds the data for sample `b` along the first (batch) dimension. Because the dict is replicated rather than split, after `scatter_kwargs` is called the keys in `input2` no longer correspond to the batch indices of each replica's `Tensor` chunk.
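To make the mismatch concrete, here is a minimal, simplified sketch of the scatter semantics described above (a hypothetical `mimic_scatter` helper using plain lists as stand-ins for tensors, not the real `scatter_kwargs`, which operates on CUDA devices):

```python
def mimic_scatter(inputs, num_replicas):
    """Simplified illustration of scatter_kwargs semantics:
    list-like "tensors" are chunked along dim 0, while anything
    else (e.g. a dict) is replicated to every replica as-is."""
    def split(obj):
        if isinstance(obj, list):  # stand-in for a Tensor
            chunk = (len(obj) + num_replicas - 1) // num_replicas
            return [obj[i * chunk:(i + 1) * chunk] for i in range(num_replicas)]
        return [obj] * num_replicas  # dicts etc. are copied whole
    scattered = [split(x) for x in inputs]
    return list(zip(*scattered))

# a batch of 4 samples plus a dict keyed by batch index
input1 = [10, 11, 12, 13]
input2 = {0: "a", 1: "b", 2: "c", 3: "d"}
replicas = mimic_scatter([input1, input2], num_replicas=2)
# replica 0 receives only the chunk [10, 11] of input1, but the
# FULL dict {0, 1, 2, 3}; replica 1 receives [12, 13] together with
# the same full dict, so its rows no longer line up with keys 0 and 1
```

Under this sketch, each replica's tensor chunk is re-indexed from 0, while the replicated dict keeps the original global batch indices, which is exactly where the correspondence breaks.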
So, is there an easy way to split the data in the dictionary along the batch dimension together with the `Tensor` inputs?