DDP + fp16 + gradient accumulation

Hi,
I'm facing a problem with DDP + fp16 + gradient accumulation.
The following is my code. train_loader spits out a mini-batch each time, and I use gradient accumulation to reach the actual batch size.
The problem is that the model trained with this DDP training script is worse than the model trained on a single GPU. Any suggestions?

Code for training:

# `amp` is a custom mixed-precision helper (its implementation is discussed below)
steps = 0
train_total_loss = 0.0

for epoch in range(num_train_epochs):
    for data, labels in tqdm(train_loader, desc=f'training-{epoch}'):
        steps += 1
        if steps % gradient_accumulation_steps == 0:
            # update step: run forward/backward with gradient sync, then step the optimizer
            with amp.context():
                output = model(data)
                loss = loss_fct(output, labels)
            amp.backward(loss)
            amp.step(model, optim)
        else:
            # accumulation step: skip DDP's gradient all-reduce via no_sync()
            with model.no_sync():
                with amp.context():
                    output = model(data)
                    loss = loss_fct(output, labels)
                amp.backward(loss)
        train_total_loss += loss.item()

Usually, when switching from local training to distributed training, you might need to re-tune things like the learning rate and other hyper-parameters, because (1) the global batch size might be different after introducing more data-parallel processes, and (2) the gradient averaging performed by DDP might or might not be mathematically equivalent to local training; it depends on the loss function.
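
To make the second point concrete, here is a rough single-process sketch (no actual DDP involved; the tensors and shapes are made up for illustration). Averaging per-rank gradients, which is what DDP's all-reduce does, only reproduces the single-GPU gradient when the local loss is a per-sample mean over equally sized shards:

import torch

torch.manual_seed(0)
w = torch.randn(4, requires_grad=True)
data = torch.randn(8, 4)
target = torch.randn(8)

# Gradient of the mean loss over the full batch (single-GPU baseline).
loss_full = ((data @ w - target) ** 2).mean()
grad_full = torch.autograd.grad(loss_full, w)[0]

# Average of per-shard mean-loss gradients, i.e. what DDP's all-reduce
# would compute for two ranks holding 4 samples each.
shard_grads = []
for d, t in zip(data.chunk(2), target.chunk(2)):
    loss = ((d @ w - t) ** 2).mean()
    shard_grads.append(torch.autograd.grad(loss, w)[0])
grad_ddp = torch.stack(shard_grads).mean(0)

print(torch.allclose(grad_full, grad_ddp))  # True: mean loss + equal shards

With a sum-reduced loss, or unequal shard sizes across ranks, the two gradients would differ, which is one way DDP training can silently diverge from the local recipe.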

Regarding the difference in model quality, does anything change if you disable AMP?

cc @mcarilli @ngimel


Yes, when I disable fp16 in single-GPU training, the model performs better during the first few logging steps. Eventually there is no significant difference with or without fp16, but training is very slow without fp16.

I'm not sure I understand this description correctly, but it seems you are hitting numerical differences early in training, while the training eventually converges fine?
Also, which amp implementation are you using? amp.context(), amp.backward(), etc. are defined neither in the native amp utilities (torch.cuda.amp) nor in the deprecated apex.amp implementation.
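
For comparison, the same accumulation pattern written against the native torch.cuda.amp utilities would look roughly like this (autocast and GradScaler are the real APIs; dividing the loss by gradient_accumulation_steps is my assumption about the intended semantics, since the original snippet does not show it):

import contextlib
import torch
from torch.cuda import amp

scaler = amp.GradScaler()
steps = 0
for epoch in range(num_train_epochs):
    for data, labels in train_loader:
        steps += 1
        sync = steps % gradient_accumulation_steps == 0
        # skip DDP's gradient all-reduce on pure accumulation steps
        ctx = contextlib.nullcontext() if sync else model.no_sync()
        with ctx:
            with amp.autocast():
                output = model(data)
                loss = loss_fct(output, labels) / gradient_accumulation_steps
            scaler.scale(loss).backward()
        if sync:
            scaler.step(optim)
            scaler.update()
            optim.zero_grad()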