I recently moved from a single-GPU setup to a multi-GPU setup. Until now I have been using a custom class for loading data during neural network training; I will refer to this class as DM, and to Torch’s standard DataLoaders as DL. I do this because I see about a 450-550% speedup in all data loading/formatting ops when using DM as opposed to DL, which radically reduces my training time. In the single-GPU setting, this has no impact on the performance of the trained model or on any loss statistics. In other words, it does nothing but cut my training time by a factor of 4-5 at no cost to model performance.
But since switching to a distributed setting, this is no longer the case. I have modified my DM class to (apparently naively) handle the distributed case, with each process started by DDP having its own DMₙ on GPU rank n. DMₙ yields batches of data as tensors on device torch.device(f"cuda:{n}") for the forward pass executed by the DDP process with rank n. Very simple…or so I thought.
After training in the distributed setting like this for a while, I tested the model’s performance: it is not merely failing to learn, it is actively getting worse the more it is trained.
But the point of this post isn’t a deep diagnosis of the problem…it’s too involved. All I am asking is this. The only part of my distributed setup that differs from the stock-standard PyTorch DDP pipeline is the change from DL to DM, so my assumption is that the drop in model performance comes from that change. Can anyone give me some idea of what else Torch’s DataLoaders do when using a DistributedSampler? Is there something else going on here other than partitioning the dataset into n subsets for n GPUs and sending the data to the right GPU at the right time?
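To make the "partitioning" part of the question concrete, here is a minimal pure-Python sketch of the per-rank index selection as I understand DistributedSampler to do it: every rank builds the same permutation from a shared seed plus the epoch number, the index list is padded (or truncated) so it divides evenly, and each rank takes a strided slice. The function `shard_indices` and its parameter names are hypothetical, and `random.Random` stands in for Torch's own generator, so the actual orderings will not match Torch's; this is an approximation for comparison against DM, not the library's code.

```python
import math
import random


def shard_indices(dataset_len, num_replicas, rank,
                  epoch=0, seed=0, shuffle=True, drop_last=False):
    """Hypothetical helper: approximate per-rank index selection.

    Assumption: mirrors the shuffle / pad-or-truncate / stride steps of
    torch.utils.data.DistributedSampler, not its exact random ordering.
    """
    indices = list(range(dataset_len))
    if shuffle:
        # All ranks seed identically, so they agree on one permutation.
        # The epoch is mixed in so the order changes from epoch to epoch.
        random.Random(seed + epoch).shuffle(indices)

    if drop_last:
        # Truncate so every rank gets the same number of samples.
        num_samples = dataset_len // num_replicas
        indices = indices[: num_samples * num_replicas]
    else:
        # Pad by repeating samples from the start of the list.
        num_samples = math.ceil(dataset_len / num_replicas)
        padding = num_samples * num_replicas - len(indices)
        if padding > 0:
            repeats = math.ceil(padding / len(indices))
            indices += (indices * repeats)[:padding]

    # Strided slice: rank r takes positions r, r + world_size, ...
    return indices[rank::num_replicas]


if __name__ == "__main__":
    for r in range(4):
        print(r, shard_indices(10, num_replicas=4, rank=r, shuffle=False))
```

Two properties worth checking a custom loader against: every rank sees the same number of batches per epoch (uneven counts can stall or skew DDP's gradient synchronization), and the shards come from one shared permutation rather than an independent shuffle per rank. With the real sampler, the per-epoch reshuffle only happens if `set_epoch(epoch)` is called before each epoch.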
I also note some distinct and regular differences in patterns displayed by GPU utilization shown in nvitop when I use Torch’s DL vs my own DM…but I won’t get into those just yet. I’m currently working my way through Torch’s DataLoader/DistributedSampler code to see if I can figure out what exactly is happening here myself, but figured I’d make this post to see if anyone out there can give me some high-level understanding of what else these things are doing outside of simply moving data to the right place at the right time.