How to access hidden states computed on a different device when using DataParallel in PyTorch?

If my batch size is set to 16 and I have 4 devices, DataParallel will split the batch into 4 chunks of size 4 and send them to different devices for the forward call. But is it possible to access the hidden states of all 16 instances inside the forward call?

You might be able to create a “global” dict and store the forward activations via a key which uses the layer name as well as the device id. Note that we generally recommend using DistributedDataParallel for better performance, where one process per device would be used and would allow you to directly access the corresponding forward activations.
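
Something along these lines could work as a starting point. It is a rough sketch with an illustrative model; the layer name, dict, and hook are just placeholders, and it assumes forward hooks registered before wrapping the model are carried over to the replicas (which is what the dict-by-device-id pattern relies on):

```python
import torch
import torch.nn as nn

activations = {}  # "global" storage shared by all replicas

def save_activation(name):
    def hook(module, inputs, output):
        # output.device.index identifies the GPU this replica ran on
        activations[(name, output.device.index)] = output.detach()
    return hook

class MyModel(nn.Module):  # hypothetical model for illustration
    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(32, 64)
        self.head = nn.Linear(64, 10)

    def forward(self, x):
        return self.head(self.encoder(x))

model = MyModel().cuda()
model.encoder.register_forward_hook(save_activation("encoder"))

model = nn.DataParallel(model)           # replicas run the same hook
out = model(torch.randn(16, 32).cuda())  # batch of 16 split across devices

# After the forward pass, activations should hold one entry per device, e.g.
# {("encoder", 0): ..., ("encoder", 1): ..., ...}
print(sorted(activations.keys()))
```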


Thanks! Does DataParallel scatter and distribute the input data in order? Or will the data end up being shuffled?

The DataParallel wrapper will slice the data (in order) and send each chunk to the corresponding device. Any shuffling is done beforehand by the sampler in the DataLoader.
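
For example, a toy sketch like this (assuming 4 visible GPUs; ShowChunk is just an illustrative module) makes the in-order split visible:

```python
import torch
import torch.nn as nn

class ShowChunk(nn.Module):
    def forward(self, x):
        # x is the per-device slice of the full batch
        print(f"device {x.device}: rows {x[:, 0].tolist()}")
        return x

model = nn.DataParallel(ShowChunk().cuda())
inp = torch.arange(16, dtype=torch.float32).view(16, 1).cuda()
model(inp)
# Expected output (the order of the print lines may vary, the slices will not):
# device cuda:0: rows [0.0, 1.0, 2.0, 3.0]
# device cuda:1: rows [4.0, 5.0, 6.0, 7.0]
# device cuda:2: rows [8.0, 9.0, 10.0, 11.0]
# device cuda:3: rows [12.0, 13.0, 14.0, 15.0]
```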


I tried adding a dict as a model attribute. The forward call would involve the following:

  1. Compute hidden states using a transformer model and save them to that dict with the current device index as the key.
  2. Gather the hidden states computed on all devices from the dict and perform some further computations.

But when I get to the second step, the dict doesn't contain entries for all devices, most likely because some devices finish before others. The devices also don't seem to execute in order: sometimes accessing the dict on device 0 returns more than one entry. Is there a way to either regulate the order in which the devices run, or to synchronize across all devices before going into step 2? Many thanks!
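
Roughly, this is what the forward looks like (a stripped-down sketch; the encoder is a stand-in for the actual transformer, and the comment marks where I think the race happens):

```python
import torch
import torch.nn as nn

class TwoStepModel(nn.Module):  # illustrative stand-in for the real model
    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(32, 64)
        self.shared_hidden = {}  # dict attribute, ends up shared across replicas

    def forward(self, x):
        # Step 1: compute this replica's hidden states and store them by device index
        h = self.encoder(x)
        self.shared_hidden[x.device.index] = h

        # Step 2: try to gather the hidden states from *all* devices.
        # This is where the race seems to happen: other replicas may not have
        # reached step 1 yet, so the dict can be missing entries (or contain
        # more than one entry if this replica is slow) at this point.
        gathered = [self.shared_hidden[k] for k in sorted(self.shared_hidden)]
        return torch.cat([t.to(x.device) for t in gathered], dim=0)
```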

Yes, use DDP, as it allows for manual communication via all_reduce, barrier, etc., which might all be needed for your custom approach.
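
For reference, a minimal DDP sketch (launched with torchrun; the model and shapes are illustrative) showing how each rank could gather the hidden states from all other ranks before the second step:

```python
import os
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

def main():
    dist.init_process_group(backend="nccl")
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)
    device = f"cuda:{local_rank}"

    model = nn.Linear(32, 64).to(device)       # stand-in for the transformer
    model = DDP(model, device_ids=[local_rank])

    x = torch.randn(4, 32, device=device)      # this rank's chunk of the batch
    hidden = model(x)                          # step 1: local hidden states

    # step 2: collect the hidden states from every rank
    gathered = [torch.empty_like(hidden) for _ in range(dist.get_world_size())]
    dist.all_gather(gathered, hidden)          # synchronizes across ranks
    all_hidden = torch.cat(gathered, dim=0)    # all 16 hidden states, in rank order

    # ... further computation on all_hidden ...
    print(local_rank, all_hidden.shape)

    dist.destroy_process_group()

if __name__ == "__main__":
    main()
```

If only synchronization (and no data exchange) is needed before step 2, dist.barrier() would be enough. Also note that the plain dist.all_gather call does not backpropagate into the tensors received from other ranks; if gradients through all hidden states matter, swapping the local slice back in for the original tensor (or using torch.distributed.nn.all_gather) might be needed, depending on the use case.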


Thanks, I will look into that.