If my batch size is set to 16 and I have 4 devices, PyTorch will split the batch into 4 chunks of size 4 and send each to a different device for the forward call. But I want to access the hidden states of all 16 instances inside the forward call. Is that possible?
You might be able to create a "global" dict and store the forward activations via a key which uses the layer name as well as the device id. Note that we generally recommend using DistributedDataParallel for better performance, where one process per device is used, which would allow you to directly access the corresponding forward activations.
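A minimal sketch of such a hook-based store (the layer name, hook helper, and dict are all illustrative, not an official API; on CPU the device index is simply `None`, while under `DataParallel` each replica would write under its own GPU index):

```python
import torch
import torch.nn as nn

# Module-level "global" dict shared by all DataParallel replicas, since the
# replicas run as threads in the same process. Keyed by (layer name, device id).
activations = {}

def make_hook(layer_name):
    def hook(module, inputs, output):
        # inputs[0].device.index is the GPU id under DataParallel, None on CPU
        key = (layer_name, inputs[0].device.index)
        activations[key] = output.detach()
    return hook

model = nn.Sequential(nn.Linear(8, 4), nn.ReLU())
model[0].register_forward_hook(make_hook("linear0"))

x = torch.randn(16, 8)
_ = model(x)
print(sorted(activations.keys()))  # [('linear0', None)] on a CPU run
```

Under `nn.DataParallel(model)`, the same hook would fire once per replica, so after one forward pass the dict would hold one entry per device id.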
Thanks! Does DataParallel scatter and distribute the input data in order? Or will the data end up being shuffled?
The DataParallel wrapper will slice the data (in order) and send each chunk to the corresponding device. The shuffling is done beforehand by the sampler in the DataLoader.
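To illustrate the in-order slicing, here is a plain-Python sketch (the function name is made up; the real scatter happens inside DataParallel and behaves like `torch.chunk` along the batch dimension):

```python
# Hypothetical sketch of how a batch is scattered across devices:
# contiguous, in-order chunks, one per device. Any shuffling already
# happened earlier, in the DataLoader's sampler.
def scatter_in_order(batch, num_devices):
    """Split `batch` into contiguous chunks, one per device, preserving order."""
    chunk_size = (len(batch) + num_devices - 1) // num_devices
    return [batch[i * chunk_size:(i + 1) * chunk_size] for i in range(num_devices)]

batch = list(range(16))  # 16 sample indices, as ordered by the sampler
chunks = scatter_in_order(batch, 4)
print(chunks)  # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]
```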
I tried adding a dict as a model attribute. The forward call would involve the following:

1. Compute hidden states using a transformer model and save them to that dict with the current device index as the key.
2. Gather all hidden states computed on all devices from the dict, then perform some further computations.
But when I get to the second step, the dict doesn't contain entries from all devices, most likely because some devices finish before others. It also seems the devices don't execute in order: sometimes accessing the dict on device 0 returns more than one entry. Is there a way to either regulate the order in which the devices run, or synchronize across all devices before going into step 2? Many thanks!
Yes, use DDP, as it will allow you to perform manual communication via all_reduce, barrier, etc., which might all be needed for your custom approach.
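A single-process sketch of that call pattern (gloo backend on CPU, `world_size=1`, so the collectives are trivial here; in a real DDP run one process per GPU executes this same code and the collectives synchronize across all ranks):

```python
import tempfile
import torch
import torch.distributed as dist

# Initialize a one-rank process group via a shared file (avoids picking a port).
init_file = tempfile.NamedTemporaryFile(delete=False).name
dist.init_process_group(
    backend="gloo",                  # CPU-friendly backend
    init_method=f"file://{init_file}",
    rank=0,
    world_size=1,
)

hidden = torch.ones(4, 3)            # this rank's hidden states
dist.barrier()                       # wait until every rank reaches step 2
dist.all_reduce(hidden, op=dist.ReduceOp.SUM)  # sum hidden states across ranks
print(hidden.sum().item())           # 12.0 with world_size=1 (nothing is added)

dist.destroy_process_group()
```

With more ranks, `all_reduce` would leave every rank holding the element-wise sum of all ranks' hidden states; `all_gather` is the alternative if you need the per-device tensors kept separate.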
Thanks, I will look into that.