How to skip some batches when training a distributed model?

I use torch.multiprocessing to launch distributed training, but some batches may raise a CUDA out-of-memory exception, and I just want to skip those batches. When training on a single GPU, I can skip them successfully with a try/except block.
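For reference, the single-GPU try/except pattern described above might look roughly like the minimal sketch below. The function name `train_step_skip_oom` is hypothetical. Matching on the "out of memory" message is an assumption chosen so the check works both with newer PyTorch versions (where `torch.cuda.OutOfMemoryError` subclasses `RuntimeError`) and older ones that raise a plain `RuntimeError`. The cleanup happens after the `except` block, because the exception's traceback can keep intermediate tensors alive while it is being handled.

```python
import torch


def train_step_skip_oom(model, optimizer, loss_fn, inputs, targets) -> bool:
    """Run one training step; return False (skipping the batch) on a CUDA OOM."""
    optimizer.zero_grad(set_to_none=True)
    loss = None
    oom = False
    try:
        loss = loss_fn(model(inputs), targets)
        loss.backward()
    except RuntimeError as e:
        # Re-raise anything that is not an out-of-memory error.
        if "out of memory" not in str(e):
            raise
        oom = True
    if oom:
        # Cleanup runs after the except block, once the exception (and the frames its
        # traceback keeps alive) has been cleared; drop the graph and partial grads too.
        loss = None
        optimizer.zero_grad(set_to_none=True)
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        return False
    optimizer.step()
    return True
```

This works on one GPU because no other process is waiting on this one. Under DDP, the same skip leaves the other ranks blocked in their gradient all-reduce.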

But this doesn't work for distributed training: the training process just hangs. I guess it may be caused by the communication between the different processes.

I'd appreciate it if someone could help me.

Yes, doing something like this will indeed cause the processes to hang: some ranks will likely have already kicked off communication while other ranks skip the batch, so the number of collective calls launched is inconsistent across ranks.

The tricky part is that you need to know a priori whether a batch will be skipped, so that the decision can be communicated consistently to all distributed processes.
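One way to make the skip decision consistent is to have each rank compute a local flag before the forward pass and then all-reduce it with `MAX`, so that if any rank wants to skip, every rank skips. This is a sketch under several assumptions: `any_rank_wants_skip`, `train_epoch`, and the `wants_skip` predicate are hypothetical names; the process group is already initialized; and every rank iterates over the same number of batches (e.g. via `DistributedSampler`). It does not rescue an OOM that happens midway through a DDP backward pass, since gradient collectives may already be in flight by then.

```python
import torch
import torch.distributed as dist


def any_rank_wants_skip(local_skip: bool, device: torch.device) -> bool:
    """All-reduce a 0/1 flag with MAX: True on every rank if any rank asked to skip."""
    # With the NCCL backend the tensor must live on this rank's GPU; Gloo accepts CPU tensors.
    flag = torch.tensor([1 if local_skip else 0], dtype=torch.long, device=device)
    dist.all_reduce(flag, op=dist.ReduceOp.MAX)
    return bool(flag.item())


def train_epoch(model, optimizer, loss_fn, loader, device, wants_skip) -> int:
    """Train for one epoch, skipping a batch on all ranks together; returns steps taken.

    `wants_skip(inputs, targets) -> bool` is a hypothetical, rank-local predicate
    that must be evaluated before any forward/backward collective is launched.
    """
    steps = 0
    for inputs, targets in loader:
        # Every rank reaches this all_reduce exactly once per batch, so the decision
        # stays consistent (assumes all ranks see the same number of batches).
        if any_rank_wants_skip(wants_skip(inputs, targets), device):
            continue
        optimizer.zero_grad(set_to_none=True)
        loss = loss_fn(model(inputs.to(device)), targets.to(device))
        loss.backward()  # with DDP, gradient all-reduces are launched here
        optimizer.step()
        steps += 1
    return steps
```

The extra all-reduce costs one tiny collective per iteration, which is usually negligible next to the gradient synchronization itself.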

Do you know why some batches result in OOM? Are the sizes somehow different across batches, and can you use the size to estimate whether you’ll need to skip it?
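If memory usage is driven by something measurable in the batch, a rank-local check can be computed from the batch alone, before the forward pass. A minimal sketch, assuming torchvision-style detection targets (a list with one dict per image, each holding an `(N, 4)` `"boxes"` tensor); `num_targets`, `exceeds_target_budget`, and the budget value are hypothetical, and the budget would have to be found empirically (e.g. the largest count that trained without OOM).

```python
from typing import Dict, List

import torch


def num_targets(targets: List[Dict[str, torch.Tensor]]) -> int:
    """Total number of ground-truth boxes across all images in the batch."""
    return sum(int(t["boxes"].shape[0]) for t in targets)


def exceeds_target_budget(targets: List[Dict[str, torch.Tensor]], max_total_targets: int) -> bool:
    """Hypothetical heuristic: flag batches with more boxes than an empirical budget."""
    return num_targets(targets) > max_total_targets
```

Because the flag depends only on the batch contents, each rank can evaluate it locally and then share it with the other ranks so they all skip the same iteration.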

Thanks! It is an object detection task, and memory usage fluctuates because the number of targets varies. I now set a maximum number of targets, and with that I can train the model successfully. 😀
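The post does not say exactly how the maximum was enforced. One possible way is to cap the number of ground-truth objects per image in the dataset or collate step. The sketch below keeps a random subset of at most `max_targets` objects; `cap_targets` is a hypothetical helper, and it assumes torchvision-style target dicts in which every per-object tensor has the number of boxes as its first dimension.

```python
from typing import Dict, Optional

import torch


def cap_targets(target: Dict[str, torch.Tensor], max_targets: int,
                generator: Optional[torch.Generator] = None) -> Dict[str, torch.Tensor]:
    """Keep at most `max_targets` objects, sampled at random, from one image's target dict."""
    n = target["boxes"].shape[0]
    if n <= max_targets:
        return target
    keep = torch.randperm(n, generator=generator)[:max_targets]
    capped = {}
    for key, value in target.items():
        # Heuristic: treat any tensor whose first dimension equals N as per-object
        # (boxes, labels, masks, area, ...) and index it; leave everything else as-is.
        if torch.is_tensor(value) and value.dim() > 0 and value.shape[0] == n:
            capped[key] = value[keep]
        else:
            capped[key] = value
    return capped
```

Sampling at random rather than keeping the first `max_targets` objects avoids always dropping the same annotations; either way, the dropped objects no longer contribute to the loss for that iteration.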