Send metadata through DataParallel

My input to the neural network is a tensor of shape (B, C, H, W), where B is the batch size. As I understand it, torch.nn.DataParallel splits the first dimension into chunks and processes each chunk on a separate GPU. In my case I also have some metadata, basically one flag per input, stored as a list of size B. This flag is essential for deciding which path to take in later stages of the network.

My problem is that DataParallel does not split the associated list into the same chunks as the input tensor; it just sends the whole list to each replica. Is there any way I can get both the input tensor and the metadata onto each GPU, split consistently?

For example, with 4 GPUs, suppose my input is a tuple of a tensor of shape (12, 3, 24, 24) and a metadata list of size 12. Each GPU then sees an input of shape (3, 3, 24, 24) but a metadata list of size 12, so I have no way of associating each input with its flag.
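
A minimal sketch of the behavior, assuming a 4-GPU machine (the module and flag values here are made up for illustration):

```python
import torch
import torch.nn as nn

class Net(nn.Module):
    def forward(self, x, flags):
        # x is scattered along dim 0, but a plain Python list is
        # replicated whole, so every replica sees all 12 flags.
        print(x.shape, len(flags))  # torch.Size([3, 3, 24, 24]) 12
        return x

model = nn.DataParallel(Net()).cuda()
x = torch.randn(12, 3, 24, 24).cuda()
flags = ["path_a"] * 12  # metadata as a Python list
model(x, flags)
```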

If you can make your metadata a tensor (instead of a Python list), DataParallel will be able to split it appropriately.
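
For instance, a sketch where the flags are encoded as an integer tensor of shape (B,), so they are chunked along dim 0 in lockstep with the input (same hypothetical 4-GPU setup as above):

```python
import torch
import torch.nn as nn

class Net(nn.Module):
    def forward(self, x, flags):
        # Both tensors are scattered along dim 0 together: each
        # replica sees x of (3, 3, 24, 24) and flags of (3,).
        print(x.shape, flags.shape)
        return x

model = nn.DataParallel(Net()).cuda()
x = torch.randn(12, 3, 24, 24).cuda()
flags = torch.arange(12).cuda()  # one integer flag per sample
model(x, flags)
```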

Thanks! This seems like a good workaround. My flags are actually strings, so I guess I have to map them to one-hot vectors before converting them to tensors.
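
A sketch of that mapping, with hypothetical flag names; an index tensor of shape (B,) is already enough for DataParallel to split, and a one-hot encoding works just as well:

```python
import torch
import torch.nn.functional as F

raw_flags = ["skip", "dense", "skip", "dense"]  # hypothetical names, length B

# Map each distinct string to an integer index.
vocab = {name: i for i, name in enumerate(sorted(set(raw_flags)))}
idx = torch.tensor([vocab[f] for f in raw_flags])         # shape (B,)
one_hot = F.one_hot(idx, num_classes=len(vocab)).float()  # shape (B, num_flags)

# Inside forward(), the original string can be recovered per sample:
inv_vocab = {i: name for name, i in vocab.items()}
print([inv_vocab[int(i)] for i in idx])  # ['skip', 'dense', 'skip', 'dense']
```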