Out of memory when resuming DDP training

I hit an out-of-memory (OOM) error when resuming DDP training with the following code:

device = torch.device("cuda:{}".format(local_rank))
model.to(device)
# map_location=device loads the checkpoint tensors directly onto this rank's GPU
model.load_state_dict(torch.load(path, map_location=device))
model = DistributedDataParallel(model, device_ids=[local_rank])

...
optimizer.step()  # <-- OOM here
...

The problem went away after I changed map_location to 'cpu' and loaded the state dict before moving the model to the device (see below). Why would the first version run out of memory?

device = torch.device("cuda:{}".format(local_rank))
# load the checkpoint into CPU memory first, then move the model to the GPU
model.load_state_dict(torch.load(path, map_location="cpu"))
model.to(device)
model = DistributedDataParallel(model, device_ids=[local_rank])
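
For reference, here is a minimal sketch of how the extra GPU allocation from map_location=device could be observed with torch.cuda.memory_allocated(). The toy model, the sizes, and the "ckpt.pt" path are just placeholders, not my actual setup:

import torch
from torch import nn

device = torch.device("cuda:0")  # single-GPU process, stand-in for cuda:{local_rank}

# Toy model standing in for the real one; sizes are arbitrary.
model = nn.Linear(4096, 4096).to(device)
torch.save(model.state_dict(), "ckpt.pt")  # hypothetical checkpoint path

def allocated_mb():
    return torch.cuda.memory_allocated(device) / 1024 ** 2

print("after moving the model to the GPU:", allocated_mb())

# map_location=device materializes a second full copy of the checkpoint
# tensors on the GPU before load_state_dict copies them into the model.
state_gpu = torch.load("ckpt.pt", map_location=device)
print("after torch.load(..., map_location=device):", allocated_mb())

del state_gpu  # drop the GPU copy before trying the CPU variant

# map_location='cpu' keeps the checkpoint in host memory, so allocated
# GPU memory barely moves until load_state_dict copies the values across.
state_cpu = torch.load("ckpt.pt", map_location="cpu")
print("after torch.load(..., map_location='cpu'):", allocated_mb())

I assume the first optimizer.step() also allocates the optimizer state (e.g. momentum buffers) on the GPU on top of whatever the checkpoint copy uses, but I'm not sure that fully explains the OOM.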