I recently discovered that my 8-GPU training hangs if I have this `if` in `backproject()` (from Atlas/model.py at master · magicleap/Atlas · GitHub). I am using DDP; all GPUs saturate at 100%, and the hang happens randomly at some epoch in the middle of a job:

    ...
    volume = torch.zeros(
        batch, channels, nx * ny * nz, dtype=features.dtype, device=device
    )
    # `valid` shape: [b, nx*ny*nz]
    if valid.any():
        for b in range(batch):
            volume[b, :, valid[b]] = features[b, :, py[b, valid[b]], px[b, valid[b]]]
    volume = volume.view(batch, channels, nx, ny, nz)
    valid = valid.view(batch, 1, nx, ny, nz)
    return volume, valid

After removing the `if`, my model trains well. The purpose of the `if` was to skip unnecessary indexing and save time. I know this might cause the GPUs to execute different graphs and diverge between samples, but is that the main reason for the hang? Is data-dependent branching like this generally discouraged in training code? (I found a post like this one that clearly uses `if`, which I assume could lead to even more divergence.)
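
To make the comparison concrete, here is a minimal sketch of the two variants on toy CPU tensors. The function names, shapes, and random indices are made up for illustration and are not from Atlas; only the indexing pattern matches the snippet above. It checks that the version without the `if` produces the same volume, including when `valid` is all False, and it prints whether `volume` stays attached to `features` in the autograd graph in each case, which might matter for DDP's gradient synchronization.

```python
import torch


def backproject_branch(features, valid, py, px):
    """Variant with the data-dependent `if` (hypothetical helper name)."""
    batch, channels = features.shape[:2]
    volume = torch.zeros(batch, channels, valid.shape[1], dtype=features.dtype)
    # `if tensor:` converts the tensor to a Python bool; on GPU this forces
    # a host-device sync, and each rank may take a different path.
    if valid.any():
        for b in range(batch):
            volume[b, :, valid[b]] = features[b, :, py[b, valid[b]], px[b, valid[b]]]
    return volume


def backproject_no_branch(features, valid, py, px):
    """Variant without the `if`; an all-False mask just writes nothing."""
    batch, channels = features.shape[:2]
    volume = torch.zeros(batch, channels, valid.shape[1], dtype=features.dtype)
    for b in range(batch):
        volume[b, :, valid[b]] = features[b, :, py[b, valid[b]], px[b, valid[b]]]
    return volume


# Toy inputs (assumed shapes): 2 samples, 3 channels, 4x5 feature map, 6 voxels.
torch.manual_seed(0)
features = torch.randn(2, 3, 4, 5, requires_grad=True)
py = torch.randint(0, 4, (2, 6))  # row index per voxel
px = torch.randint(0, 5, (2, 6))  # column index per voxel
valid_some = torch.tensor([[True, False, True, False, False, True],
                           [False, False, True, True, False, False]])
valid_none = torch.zeros(2, 6, dtype=torch.bool)

for name, valid in [("some valid", valid_some), ("none valid", valid_none)]:
    with_if = backproject_branch(features, valid, py, px)
    without_if = backproject_no_branch(features, valid, py, px)
    print(name,
          "| same values:", torch.equal(with_if, without_if),
          "| requires_grad with if:", with_if.requires_grad,
          "| without if:", without_if.requires_grad)
```

When no voxel is valid, the variant with the `if` returns a plain zero tensor that was never touched by `features`, so on that rank nothing upstream of `features` would receive a gradient through this op. I have not verified that this is what stalls DDP in my job, but it seems worth ruling out.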
Any insights are appreciated!