I have been trying to minimize my memory footprint while training with DDP and came across ZeroRedundancyOptimizer. I implemented it similarly to the example in the docs, by simply wrapping the AdamW class in it. This is a (rough) version of the training loop I am using, which is called from torch.multiprocessing.spawn.
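For context, the workers are launched roughly like this (simplified; I am assuming a single node here, so the process index from spawn doubles as both the GPU id and the rank, and the master address/port are just placeholders):

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def train(gpu, world_size):
    # gpu is the process index that mp.spawn passes in; on a single node
    # I use it as both the local GPU id and the global rank.
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    dist.init_process_group('nccl', rank=gpu, world_size=world_size)
    ...  # model, optimizer, and loop as shown below

if __name__ == '__main__':
    world_size = torch.cuda.device_count()
    mp.spawn(train, args=(world_size,), nprocs=world_size, join=True)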
from torch.cuda.amp import GradScaler, autocast
from torch.distributed.optim import ZeroRedundancyOptimizer

device = f'cuda:{gpu}'
torch.cuda.set_device(gpu)  # gpu goes from 0 to N-1 for the number of GPUs
net = net.to(device)
net = torch.nn.parallel.DistributedDataParallel(net, device_ids=[gpu],
                                                gradient_as_bucket_view=True,
                                                find_unused_parameters=True)
scaler = GradScaler()
optimizer = ZeroRedundancyOptimizer(net.parameters(),
                                    optim=torch.optim.AdamW,
                                    lr=learning_rate)

for sample, ground_truth in train_loader:
    sample, ground_truth = sample.to(device), ground_truth.to(device)
    optimizer.zero_grad(set_to_none=True)
    with autocast():
        pred = net(sample)
        loss = loss_fn(pred, ground_truth)
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
However, I am getting the error: "Optimizer state has not been consolidated. Returning the local state. Please call 'consolidate_state_dict()' beforehand if you meant to save the global state."
Am I doing something wrong with this implementation?
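From the message, I am guessing it is objecting to how I save checkpoints, i.e. that I call optimizer.state_dict() without consolidating the sharded state first. A minimal sketch of what I assume it wants (every rank participates in the consolidation, only rank 0 writes; checkpoint_path is just a placeholder):

# Gather the full optimizer state onto rank 0 (the default recipient)
optimizer.consolidate_state_dict()
if gpu == 0:
    torch.save({'model': net.state_dict(),
                'optimizer': optimizer.state_dict()},  # global state, only valid on rank 0
               checkpoint_path)
dist.barrier()  # keep the other ranks in sync while rank 0 writes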
I am running torch==1.8.1 and cudatoolkit==10.2.