I got a CUDA OOM error when resuming DDP training:
device = torch.device("cuda:{}".format(local_rank))
model.to(device)
model.load_state_dict(torch.load(path, map_location=device))
model = DistributedDataParallel(model, device_ids=[local_rank])
...
optimizer.step()  # <-- OOM here
...
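
In case it's useful, here is a minimal sketch for measuring what the two map_location choices do to peak GPU memory when the model already lives on the GPU (the helper below is illustrative only; model, path, and device are the same placeholders as in the snippet above):

import torch

def peak_gpu_after_load(model, path, map_location, device):
    # Illustrative helper: peak allocated GPU memory while loading a checkpoint.
    torch.cuda.synchronize(device)
    torch.cuda.reset_peak_memory_stats(device)
    state = torch.load(path, map_location=map_location)  # tensors land wherever map_location says
    model.load_state_dict(state)  # in-place copy_ into the existing parameters
    del state
    torch.cuda.synchronize(device)
    return torch.cuda.max_memory_allocated(device)

print(peak_gpu_after_load(model, path, device, device))  # checkpoint materialized on the GPU first
print(peak_gpu_after_load(model, path, "cpu", device))   # checkpoint stays in host memory

With map_location=device, torch.load materializes a full second copy of every checkpoint tensor on the GPU before load_state_dict copies it into the existing parameters, so the peak is roughly twice the model size; with 'cpu' that extra copy stays in host memory.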
After I changed map_location to 'cpu', the problem went away. But I'm wondering why the first case fails. Here is the working version:
device = torch.device("cuda:{}".format(local_rank))
model.load_state_dict(torch.load(path, map_location="cpu"))
model.to(device)
model = DistributedDataParallel(model, device_ids=[local_rank])
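
For completeness, this is what I'd expect the full resume path to look like, including the optimizer state (a sketch; the checkpoint keys "model"/"optimizer" and the SGD optimizer are my assumptions, since the snippets above elide those parts):

import torch
from torch.nn.parallel import DistributedDataParallel

device = torch.device("cuda:{}".format(local_rank))
checkpoint = torch.load(path, map_location="cpu")  # keep all checkpoint tensors in host memory

model.load_state_dict(checkpoint["model"])  # assumed key; model is still on the CPU here
model.to(device)  # the only GPU copy of the weights
model = DistributedDataParallel(model, device_ids=[local_rank])

optimizer = torch.optim.SGD(model.parameters(), lr=0.01)  # hypothetical optimizer/lr
optimizer.load_state_dict(checkpoint["optimizer"])  # assumed key; load_state_dict moves the
# optimizer state tensors to match each parameter's device, so they end up on the right GPU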