I am dealing with a memory crunch on a machine with four 11 GB GTX 1080 Ti GPUs. My model with a batch size of 32 fits in memory with about 2 GB to spare on each GPU, but when it comes to computing the loss I get an out-of-memory error. I added pdb.set_trace() and checked the memory usage at each step. It seems my loss function takes 3 GB of memory, all of it on the first GPU, which causes the out-of-memory error.
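A small helper like the one below, a sketch that uses only torch.cuda calls, can replace the manual pdb inspection and print per-GPU usage before and after the loss computation, which makes the imbalance on GPU 0 easy to see. The function name gpu_memory_report is hypothetical. Note that memory_allocated counts tensors held by PyTorch's caching allocator, so it will read lower than the figure nvidia-smi shows.

```python
import torch


def gpu_memory_report():
    """Return {device_index: (allocated_GiB, peak_allocated_GiB)} for every visible GPU."""
    report = {}
    for idx in range(torch.cuda.device_count()):
        # Memory currently occupied by live tensors on this device.
        allocated = torch.cuda.memory_allocated(idx) / 2**30
        # Highest value of the above since the program started (or since the last reset).
        peak = torch.cuda.max_memory_allocated(idx) / 2**30
        report[idx] = (allocated, peak)
    return report


# Example: call this right before and right after the loss computation.
for idx, (allocated, peak) in gpu_memory_report().items():
    print(f"cuda:{idx}: {allocated:.2f} GiB allocated, {peak:.2f} GiB peak")
```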
Here is the link to the loss function I am using for object detection with online hard example mining.
How can I handle this efficiently?
Is it possible to put the loss function inside the DataParallel module? If so, how would that work, and would it require changes to the loss function?
I am not sure how to approach this.
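One commonly used pattern, sketched here under assumptions: nn.DataParallel gathers the replicas' outputs onto its output_device (GPU 0 by default), so a loss computed on those gathered outputs runs entirely on GPU 0. Wrapping the model and the loss in a single module lets each replica compute the loss on its own slice of the batch, and only the small per-replica loss values are gathered. ModelWithLoss is a hypothetical name, and nn.Linear and nn.CrossEntropyLoss stand in for the real detector and the hard-mining loss.

```python
import torch
import torch.nn as nn


class ModelWithLoss(nn.Module):
    """Hypothetical wrapper: runs the model and the loss in one forward pass
    so that nn.DataParallel computes the loss on every GPU replica."""

    def __init__(self, model: nn.Module, criterion: nn.Module):
        super().__init__()
        self.model = model
        self.criterion = criterion

    def forward(self, images, targets):
        outputs = self.model(images)
        loss = self.criterion(outputs, targets)
        # Return a 1-element tensor so DataParallel concatenates the
        # per-replica losses along dim 0 instead of gathering scalars.
        return loss.unsqueeze(0)


# Stand-ins for the real detector and loss (hypothetical).
model = nn.Linear(16, 4)
criterion = nn.CrossEntropyLoss()

wrapped = nn.DataParallel(ModelWithLoss(model, criterion))
if torch.cuda.is_available():
    wrapped = wrapped.cuda()

images = torch.randn(32, 16)
targets = torch.randint(0, 4, (32,))  # batch dimension first, so it is split like images
if torch.cuda.is_available():
    images, targets = images.cuda(), targets.cuda()

per_gpu_losses = wrapped(images, targets)  # shape: (number of replicas,)
loss = per_gpu_losses.mean()
loss.backward()
```

Two caveats if the loss does hard mining: each replica would pick hard examples only from its own slice of the batch rather than from all 32 images, so averaging the per-replica losses only approximates the single-GPU loss. Also, DataParallel splits tensors along dim 0, so targets stored as a list of per-image tensors would not be split per image and would first need to be packed (for example padded) into a batch-first tensor.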
Thanks in advance.