Advice for using DDP for dictionary input to models

Hi! I’m trying to set up multi-GPU training for a fairly complicated model, but I’m not sure I’m heading in the right direction. I’ve been reading tutorials, including this and this, and I just realized from the comments in the class that DDP appears to split the input along its first dimension, whereas my input is a dictionary. My model is a multi-encoder-decoder framework over entities and relations: there are several encoders for entities and several pairwise decoders for relations, and the objective is link prediction. Each input batch is a dictionary like

    {
        'entity_A': some_tensor,  # size is [num_A, seq_len_A]
        'entity_B': some_tensor,  # size is [num_B, seq_len_B]
        'A_B': some_tensor,  # size is [num_AB, 3], 3 = head, relation, tail
        'A_C': some_tensor,  # size is [num_AC, 3], 3 = head, relation, tail
    }

I’m thinking of manually making DDP split only the relations across devices, while all entities in the batch are still passed to forward. Then, inside the model, we would keep only the entities that those relations involve (and remap the indices accordingly). I’m not sure this is the best way, though, and I’m also not familiar with how to “require DDP to only split certain inputs to forward”.
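
To make that idea concrete, here is a minimal sketch in plain Python, with lists standing in for tensors. It assumes each process knows its rank and world_size, that the triples are divided with a simple strided slice, and that heads index into entity_A while tails index into entity_B. The helper names compact and local_batch are hypothetical and not part of DDP or any library.

```python
def compact(indices):
    """Return the sorted unique indices and a map from old index to new index."""
    kept = sorted(set(indices))
    return kept, {old: new for new, old in enumerate(kept)}


def local_batch(batch, rank, world_size):
    """Build the per-process batch: a slice of the relations plus only the
    entities those relations refer to, with head/tail indices renumbered."""
    # Strided split: rank r takes triples r, r + world_size, r + 2 * world_size, ...
    triples = batch["A_B"][rank::world_size]

    # Assumption: column 0 (head) indexes entity_A, column 2 (tail) indexes entity_B.
    kept_a, map_a = compact(h for h, _, _ in triples)
    kept_b, map_b = compact(t for _, _, t in triples)

    return {
        "entity_A": [batch["entity_A"][i] for i in kept_a],
        "entity_B": [batch["entity_B"][i] for i in kept_b],
        "A_B": [(map_a[h], r, map_b[t]) for h, r, t in triples],
    }


# Toy batch: strings stand in for the [seq_len] rows of each entity tensor.
toy_batch = {
    "entity_A": ["a0", "a1", "a2", "a3"],
    "entity_B": ["b0", "b1", "b2"],
    "A_B": [(0, 0, 2), (3, 1, 0), (1, 0, 2), (3, 0, 1)],
}
rank0 = local_batch(toy_batch, rank=0, world_size=2)
rank1 = local_batch(toy_batch, rank=1, world_size=2)
```

With real tensors the same steps would use indexing and a lookup table instead of list comprehensions, and the split could just as well happen in the data loading code rather than inside the model.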

I am not using pytorch-geometric, but I am looking at their usage of DDP here. It looks like one can customize how data is fed to each machine via DistributedSampler (?), and DDP isn’t directly truncating the multiple inputs in the same way. But I’m still a bit unsure how DDP would work in those cases. Any pointers and references would be greatly appreciated!

DDP does not split the input; it receives a mini-batch from the dataloader directly. However, I think it requires the input to be mostly a tensor or a tuple of tensors. One thing you can try is to extract the input tensors from the dict and feed them to the DDP-wrapped model.
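
A rough sketch of that suggestion, assuming the model's forward takes the four tensors positionally in a fixed order. The helper call_with_dict and the BATCH_KEYS tuple are hypothetical names made up for illustration; only DistributedDataParallel in the usage comment is an actual PyTorch class.

```python
# Hypothetical fixed key order; the names follow the batch dict in the question.
BATCH_KEYS = ("entity_A", "entity_B", "A_B", "A_C")


def call_with_dict(model, batch, keys=BATCH_KEYS):
    """Unpack `batch` in a fixed key order and pass the values positionally.

    `model` can be any callable, for example a DDP-wrapped nn.Module whose
    forward signature is forward(entity_a, entity_b, a_b, a_c).
    """
    missing = [k for k in keys if k not in batch]
    if missing:
        raise KeyError(f"batch is missing keys: {missing}")
    return model(*(batch[k] for k in keys))


# Usage sketch (assumes torch.distributed is already initialized, e.g. by
# torchrun, and that `model`, `local_rank` and `loader` exist in your script):
#   ddp_model = torch.nn.parallel.DistributedDataParallel(
#       model.to(local_rank), device_ids=[local_rank])
#   for batch in loader:
#       scores = call_with_dict(ddp_model, batch)
```

Keeping the key order in one place means the dataloader can keep producing dicts while the wrapped model only ever sees plain tensors.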

Oh yes, that makes sense. So the customization needs to be made in the DistributedSampler? Could you point me to some codebases that have made such changes? Thank you!