Hi,
I recently discovered that my 8-GPU training hangs if I include the `if` below (using DDP; all GPUs saturate at 100%, and it happens randomly at some epoch in the middle of a job). The snippet is a modification of backproject() from Atlas/model.py (master branch) in the magicleap/Atlas repository on GitHub:
```python
...
volume = torch.zeros(
    batch, channels, nx * ny * nz, dtype=features.dtype, device=device
)
# `valid` shape: [b, nx*ny*nz]
if valid.any():
    for b in range(batch):
        volume[b, :, valid[b]] = features[b, :, py[b, valid[b]], px[b, valid[b]]]
volume = volume.view(batch, channels, nx, ny, nz)
valid = valid.view(batch, 1, nx, ny, nz)
return volume, valid
```
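One plausible mechanism (not confirmed for this model): with DDP's default `find_unused_parameters=False`, every rank is expected to produce gradients for the same parameters so that the gradient all-reduce buckets line up. If one rank takes the `valid.any() == False` path, `volume` no longer depends on `features`, so the layers that produced `features` may get no gradient on that rank, and the other ranks can end up waiting in an all-reduce that never completes. That would look like a hang with GPUs pinned at 100%. The single-process CPU sketch below only shows the graph difference; the toy shapes and the helper names `index_into_volume` and `features_get_gradient` are hypothetical.

```python
import torch


def index_into_volume(features, px, py, valid, use_branch):
    # Mirrors the indexing step of the modified backproject() above.
    # Assumed shapes: features [b, c, h, w]; px, py [b, n] long; valid [b, n] bool.
    batch, channels = features.shape[:2]
    volume = torch.zeros(batch, channels, valid.shape[1], dtype=features.dtype)
    if (not use_branch) or valid.any():
        for b in range(batch):
            volume[b, :, valid[b]] = features[b, :, py[b, valid[b]], px[b, valid[b]]]
    return volume


def features_get_gradient(use_branch):
    torch.manual_seed(0)
    features = torch.randn(2, 3, 4, 5, requires_grad=True)  # stand-in for backbone output
    px = torch.zeros(2, 6, dtype=torch.long)
    py = torch.zeros(2, 6, dtype=torch.long)
    valid = torch.zeros(2, 6, dtype=torch.bool)  # this rank's batch: no voxel is visible
    volume = index_into_volume(features, px, py, valid, use_branch)
    # Another (hypothetical) loss term keeps backward() legal when volume has no graph.
    other = torch.ones(1, requires_grad=True)
    loss = volume.sum() + other.sum()
    loss.backward()
    return features.grad is not None


print(features_get_gradient(use_branch=True))   # False: features left out of the graph
print(features_get_gradient(use_branch=False))  # True: an all-zero gradient, but present
```

Without the `if`, indexing with an all-False mask is a no-op on the values but still records the op in the autograd graph, so every rank reports a (zero) gradient for the same parameters.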
After removing the `if`, my model trains well. The purpose of the `if` was to skip unnecessary indexing and save time. I know it might cause the GPUs to execute different graphs and diverge between samples, but is that the main reason? Is data-dependent branching like this generally discouraged in training code? (I found posts like this one that clearly use an `if`, which I assume could lead to even more divergence.)
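If the goal is only to avoid wasted work, one branch-free alternative (a sketch, not Atlas's own code) is to gather features for every voxel with clamped coordinates and zero out the invalid ones with the mask. Every rank then runs the same ops, and `features` always stays in the graph. This assumes `px` and `py` are long tensors of shape [b, nx*ny*nz] and that invalid voxels may carry out-of-range coordinates (hence the clamp); `backproject_gather` and `backproject_loop` are hypothetical names.

```python
import torch


def backproject_gather(features, px, py, valid):
    # Assumed shapes: features [b, c, h, w]; px, py [b, n] long; valid [b, n] bool.
    batch, channels, height, width = features.shape
    # Clamp so invalid voxels still point at a real pixel; the mask zeroes them below.
    flat_idx = py.clamp(0, height - 1) * width + px.clamp(0, width - 1)  # [b, n]
    flat_feat = features.reshape(batch, channels, height * width)  # [b, c, h*w]
    index = flat_idx.unsqueeze(1).expand(-1, channels, -1)  # [b, c, n]
    gathered = torch.gather(flat_feat, 2, index)  # [b, c, n]
    return torch.where(valid.unsqueeze(1), gathered, torch.zeros_like(gathered))


def backproject_loop(features, px, py, valid):
    # The per-sample loop from the post, without the `if`, for comparison.
    batch, channels = features.shape[:2]
    volume = torch.zeros(batch, channels, valid.shape[1], dtype=features.dtype)
    for b in range(batch):
        volume[b, :, valid[b]] = features[b, :, py[b, valid[b]], px[b, valid[b]]]
    return volume


torch.manual_seed(0)
features = torch.randn(2, 3, 4, 5)
px = torch.randint(-2, 7, (2, 10))  # some coordinates fall outside the 4x5 image
py = torch.randint(-2, 6, (2, 10))
valid = (px >= 0) & (px < 5) & (py >= 0) & (py < 4)
print(torch.allclose(backproject_gather(features, px, py, valid),
                     backproject_loop(features, px, py, valid)))
```

Note also that `if valid.any():` on a CUDA tensor forces a GPU-to-CPU synchronization to evaluate the Python `bool`, so the branch may save less time than it appears to.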
Any insights are appreciated!