Reshaping tensors while using model parallelism

I’d like to adapt the DLRM model mentioned here [1] for a recommendations-ranking use case that requires model parallelism (I have several TB of training data and large embedding tables). I was hoping to use the ListNet loss from here [2]. However, the original DLRM model uses a pointwise loss whose loss function expects a 1D tensor of shape [batch], whereas the ListNet loss I linked to expects a tensor of shape [batch, items], since each user has a variable number of retrieved items that the model must learn to sort optimally.
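For context, this is the shape mismatch I mean, sketched with toy tensors rather than the real DLRM or allRank code (the listwise part is just the top-1 ListNet objective, without the padding handling the allRank implementation adds):

```python
import torch
import torch.nn.functional as F

# Pointwise: the stock DLRM head emits one logit per (user, item) row.
logits = torch.randn(8)                      # shape [batch]
labels = torch.randint(0, 2, (8,)).float()   # shape [batch]
pointwise_loss = F.binary_cross_entropy_with_logits(logits, labels)

# Listwise: ListNet compares the softmax of the predicted scores with the
# softmax of the true relevances per user, so it needs each user's candidate
# items grouped along a second dimension.
scores = torch.randn(2, 4)                   # shape [batch_users, items]
relevance = torch.rand(2, 4)                 # shape [batch_users, items]
listnet_loss = (
    -(F.softmax(relevance, dim=1) * F.log_softmax(scores, dim=1)).sum(dim=1).mean()
)
```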

I can think of two approaches to adapt the code.

  1. Update the DLRM code so that all forward-pass operations expect tensors of shape [batch, items, features] rather than [batch, features]. This feels risky because I would need to modify distributed modules such as the embedding bag collections and potentially the KeyedJaggedTensor handling.

  2. Include a user_id index-mapping tensor in my batch and use it to convert the 1D predictions and labels into the 2D format that my loss function expects (see the sketch after this list). This feels less risky because I wouldn’t need to touch any built-in torchrec modules; I’d only need to add perhaps ten lines of code between the output of the final forward pass and the loss calculation. However, I plan to use torchrec’s model-parallelism capability, and my experience with distributed systems is that collecting distributed results often changes the order of the collected rows. My worry is that the forward pass might reorder records so that they no longer line up with the index-mapping tensor, which comes from the data loader and never passes through the model. I don’t plan to explicitly re-order rows in my forward passes, and I was hoping to use the DLRM design mostly as-is, so I may have nothing to worry about, but I wanted to confirm, because if my concern is valid the model would learn complete noise. Lastly, would this approach be inefficient during training because I’d likely need to pull the predictions back to the CPU and then copy them back to the GPU for the reshaping step?
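To make approach 2 concrete, the glue I have in mind between the forward pass and the loss is roughly the following. The names regroup_for_listwise, list_index, and position are placeholders I invented; the index tensors would come from my data loader alongside the features:

```python
import torch

def regroup_for_listwise(values, list_index, position, num_lists, max_items, pad_value=-1.0):
    # Scatter flat per-row values of shape [batch] into a padded
    # [num_lists, max_items] tensor; runs on whatever device `values` is on.
    out = torch.full((num_lists, max_items), pad_value,
                     device=values.device, dtype=values.dtype)
    out[list_index, position] = values
    return out

# Toy batch: 5 flat rows belonging to 2 users, with list lengths 3 and 2.
preds  = torch.tensor([0.2, 1.1, -0.4, 0.9, 0.3])   # model output, shape [batch]
labels = torch.tensor([1.0, 0.0, 0.0, 1.0, 0.0])    # shape [batch]
list_index = torch.tensor([0, 0, 0, 1, 1])  # which user's list each row belongs to
position   = torch.tensor([0, 1, 2, 0, 1])  # slot within that user's list

preds_2d  = regroup_for_listwise(preds,  list_index, position, num_lists=2, max_items=3)
labels_2d = regroup_for_listwise(labels, list_index, position, num_lists=2, max_items=3)
# preds_2d and labels_2d now have shape [batch_users, items]; the padded slots
# hold pad_value, which a masking-aware listwise loss can ignore.
```

This only works if preds[i] still lines up with row i of list_index and position after the sharded forward pass, which is exactly the ordering guarantee I’m asking about. It also looks like the scatter could stay on the GPU, but I’d like to confirm that as well.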

[1] https://github.com/pytorch/torchrec/blob/main/torchrec/models/dlrm.py
[2] https://github.com/allegro/allRank/blob/master/allrank/models/losses/listNet.py