DataParallel imbalanced memory usage

To compute the loss on each GPU instead of gathering all the outputs on the first one, wrap the model and the loss in a single module:

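```python
import torch.nn as nn
from torch.nn import DataParallel

class FullModel(nn.Module):
  def __init__(self, model, loss):
    super(FullModel, self).__init__()
    self.model = model
    self.loss = loss

  def forward(self, inputs, targets):
    outputs = self.model(inputs)
    loss = self.loss(outputs, targets)
    # Return only the loss, so each replica sends back a scalar
    # instead of its full output tensor.
    return loss

# `model`, `loss`, `fake_input` and `fake_target` are assumed to be defined already.
full_model = DataParallel(FullModel(model, loss), device_ids=[0, 1, 2])  # add any other DataParallel arguments here
full_model.cuda()

loss = full_model(fake_input, fake_target)
print(loss.size())  # torch.Size([3]), one value per GPU (if your loss returns a 0-dimensional tensor)
final_loss = loss.sum()
```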

For the optimizer, the extra memory depends on which optimizer you use.
Adam, for example, stores two extra tensors for every parameter (the running averages of the gradient and of the squared gradient, `exp_avg` and `exp_avg_sq`), each the same size as the weights, and it needs them to compute every update. Since all the gradients are accumulated on a single GPU (the first one in `device_ids`), this state lives there as well.
That being said, the weights of a classic network should not be a very large part of its memory consumption: the intermediate activations kept for the backward pass are the most demanding. So hopefully this extra memory, about twice the size of the weights, should not be too big a problem.
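To get a feel for the numbers, here is a rough back-of-the-envelope sketch (not a profiler) of the parameter-related memory on the GPU that holds the weights. It assumes fp32 tensors (4 bytes per element), Adam's two moment buffers, and a hypothetical 25-million-parameter network; the helper `param_memory_bytes` is made up for illustration, and activations are deliberately left out.

```python
def param_memory_bytes(num_params: int, bytes_per_elem: int = 4,
                       optimizer_states_per_param: int = 2) -> dict:
    """Rough memory for weights, gradients and optimizer state.

    Assumptions (hypothetical helper): fp32 tensors and Adam-style
    optimizer state with two buffers per parameter. Activations are
    NOT counted.
    """
    weights = num_params * bytes_per_elem
    grads = num_params * bytes_per_elem  # one .grad tensor per parameter
    optimizer_state = optimizer_states_per_param * num_params * bytes_per_elem
    return {
        "weights": weights,
        "grads": grads,
        "optimizer_state": optimizer_state,
        "total": weights + grads + optimizer_state,
    }


# Hypothetical network with 25 million parameters.
mem = param_memory_bytes(25_000_000)
for kind, nbytes in mem.items():
    print(f"{kind:>16}: {nbytes / 2**20:8.1f} MiB")
```

Under these assumptions, weights, gradients and Adam state together come to four times the weight memory (about 380 MiB here), which can still be small next to the activations of a large batch.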
