Is there a way to split non-tensor arguments when using torch.nn.DataParallel?
I am passing annotations to my module for computing the loss, and I want that list to be split according to the batch size as well. I know I could handle this by converting the annotations to a tensor, but since these annotations are objects with quite a lot of different properties, that is something I would rather avoid.
It would be nice if torch.nn.DataParallel accepted a callback function at initialization, similar to how the collate_fn argument of torch.utils.data.DataLoader works.
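In the meantime, here is a minimal sketch of one possible workaround, under two assumptions I am making (they are not from the question itself): the annotations arrive as a plain Python list ordered like the batch, and replicas receive contiguous chunks the way torch.chunk splits a tensor along dim 0. The helper name scatter_list and the example annotation dicts are hypothetical.

```python
import math

def scatter_list(items, num_replicas):
    """Split a list into contiguous chunks, mirroring how torch.chunk
    splits a tensor along dim 0: each chunk gets ceil(n / num_replicas)
    elements, and the final chunk may be smaller."""
    chunk_size = math.ceil(len(items) / num_replicas)
    return [items[i:i + chunk_size] for i in range(0, len(items), chunk_size)]

# Hypothetical usage: annotations for a batch of 5 samples, 2 replicas.
annotations = [{"id": i, "boxes": []} for i in range(5)]
chunks = scatter_list(annotations, num_replicas=2)
# chunks[0] covers samples 0-2, chunks[1] covers samples 3-4,
# matching the tensor slices each replica would receive.
```

One place this could be applied is a DataParallel subclass that overrides its scatter method and splits the annotation list alongside the tensor inputs, but whether that lines up exactly with the internal scatter order should be verified against the PyTorch version in use.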