So my PyTorch training step looks like this:
```python
gathered_output = torch.nn.parallel.data_parallel(model, model_input, range(ngpu))
optimizer.zero_grad()
loss = F.loss(gathered_output, ground_truth)
loss.backward()
optimizer.step()
```
I am doing a large image segmentation task, and GPU 0 uses almost 20% more memory than the others. Is there a way I can mitigate this? The data gathered back isn't that much, just an [8, 128, 300, 400] float tensor on CUDA. I have 4 GPUs, and each GPU works on 2 samples of the batch. But I guess the gradients are all gathered back to GPU 0 too?
Here I used the functional version of data_parallel; is there any difference compared to using the nn.DataParallel wrapper?
If I compute the loss on each GPU and only gather back the losses, will I save memory? Please help.
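Something like this sketch is what I have in mind: wrapping the model and the loss together in one module so each replica reduces its own output to a scalar, and only the per-GPU losses travel back to GPU 0. (F.cross_entropy here is just a stand-in for my actual loss, and the names are made up.)

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ModelWithLoss(nn.Module):
    """Computes the loss inside forward(), so under data_parallel each
    replica reduces its own output on its own GPU and only a scalar
    per-GPU loss is gathered back to the output device."""
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x, target):
        output = self.model(x)
        # Reduce to a scalar on this replica's device; unsqueeze(0) gives
        # it a batch dimension so the gather step can concatenate the
        # per-GPU losses into one tensor.
        return F.cross_entropy(output, target).unsqueeze(0)

# Usage sketch (targets get scattered along dim 0 together with inputs):
#   wrapped = ModelWithLoss(model)
#   losses = torch.nn.parallel.data_parallel(
#       wrapped, (model_input, ground_truth), range(ngpu))
#   loss = losses.mean()   # average the per-GPU losses on GPU 0
#   loss.backward()
```

Would this actually avoid materializing the full [8, 128, 300, 400] output tensor on GPU 0, or does the backward pass gather everything there anyway?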