Mixed precision training implementation

Hi!

So with the release of 1.6 I tried to optimize some of my code for mixed precision to see if I would get a measurable speed-up.

I tried two training loops (both at 99% GPU usage on an RTX 2060), but I don’t measure any speed-up, so I would like to show them here first to check whether my implementation is correct. My implementation is based on the code sample at https://pytorch.org/blog/accelerating-training-on-nvidia-gpus-with-pytorch-automatic-mixed-precision/

The train function WITHOUT mixed precision:

# imports assumed by both training functions
import time

import torch
import torch.nn.functional as F


def train(model, device, train_loader, criterion, optimizer, scheduler, epoch, iter_meter, writer):
    model.train()
    data_len = len(train_loader.dataset)
    train_start_time = time.time()
    for batch_idx, _data in enumerate(train_loader):
        spectrograms, labels, input_lengths, label_lengths = _data
        spectrograms, labels = spectrograms.to(device), labels.to(device)

        optimizer.zero_grad()

        output = model(spectrograms)  # (batch, time, n_class)
        output = F.log_softmax(output, dim=2)
        output = output.transpose(0, 1)  # (time, batch, n_class)

        loss = criterion(output, labels, input_lengths, label_lengths)
        loss.backward()

        writer.add_scalar("Loss/train", loss.item(), iter_meter.get())
        writer.add_scalar("learning_rate", scheduler.get_last_lr()[0], iter_meter.get())
        optimizer.step()
        scheduler.step()
        iter_meter.step()
        
    return loss.item()  # loss of the last batch

And the train function WITH mixed precision:

def train(model, device, train_loader, criterion, optimizer, scheduler, epoch, iter_meter, scaler, writer):
    model.train()
    data_len = len(train_loader.dataset)
    train_start_time = time.time()
    for batch_idx, _data in enumerate(train_loader):

        spectrograms, labels, input_lengths, label_lengths = _data
        spectrograms, labels = spectrograms.to(device), labels.to(device)

        optimizer.zero_grad()
        with torch.cuda.amp.autocast():
            output = model(spectrograms)  # (batch, time, n_class)
            output = F.log_softmax(output, dim=2)
            output = output.transpose(0, 1)  # (time, batch, n_class)

            loss = criterion(output, labels, input_lengths, label_lengths)

        # Mixed precision
        scaler.scale(loss).backward()  # loss.backward()

        scaler.step(optimizer)  # optimizer.step()
        # Updates the scale for the next iteration
        scaler.update()

        scheduler.step()  # Should I also put these steps through the scaler?
        iter_meter.step()

        writer.add_scalar("Loss/train", loss.item(), iter_meter.get())
        writer.add_scalar("learning_rate", scheduler.get_last_lr()[0], iter_meter.get())
    return loss.item()  # loss of the last batch

The main difference is that I pass the inputs to the model and calculate the loss inside the with torch.cuda.amp.autocast(): block, and later call backward through the scaler.

Training takes the same time with both functions; my thoughts on possible reasons are:

  • PyTorch somehow already applies automatic mixed precision in both cases (making the two functions effectively equal)
  • I missed something in my implementation (see the dtype check after this list)
  • My model can’t benefit from a speed-up through mixed precision training
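One quick sanity check for the first two points (a minimal sketch, assuming model and spectrograms are available as in the training loop above): print the dtype of the activations inside the autocast region. For a model ending in e.g. a linear or convolutional layer, autocast should produce torch.float16 outputs; if you see torch.float32, both loops really are doing the same work.

with torch.cuda.amp.autocast():
    output = model(spectrograms)
print(output.dtype)  # expected torch.float16 under autocast for most layer types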

In case someone wants to check the original model being trained, I have it on GitHub.

Any experienced eyes that can give it a look?

Could you profile the model execution separately to avoid profiling the data loading as well?
To do this, you could use random CUDA tensors as the input and profile the forward (and backward) passes separately:

import time

import torch

data = torch.randn(...).cuda()  # fill in your model's input shape; model is assumed to be on the GPU

nb_iters = 100

# warmup
for _ in range(10):
    out = model(data)

torch.cuda.synchronize()
t0 = time.time()
for _ in range(nb_iters):
    out = model(data)
torch.cuda.synchronize()
t1 = time.time()
print('forward ', (t1 - t0)/nb_iters)
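
# To see whether the forward pass itself benefits, you could repeat the same
# measurement under autocast (assuming native AMP is available, i.e. PyTorch 1.6+):

# warmup with autocast
for _ in range(10):
    with torch.cuda.amp.autocast():
        out = model(data)

torch.cuda.synchronize()
t0 = time.time()
for _ in range(nb_iters):
    with torch.cuda.amp.autocast():
        out = model(data)
torch.cuda.synchronize()
t1 = time.time()
print('forward amp ', (t1 - t0)/nb_iters)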

# warmup
for _ in range(10):
    out = model(data)
    out.mean().backward()

torch.cuda.synchronize()
t0 = time.time()
for _ in range(nb_iters):
    out = model(data)
    out.mean().backward()
torch.cuda.synchronize()
t1 = time.time()
print('backward ', (t1 - t0)/nb_iters)
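
The same comparison can be made for forward + backward through a GradScaler (a sketch; the mean() reduction just produces a scalar loss for timing purposes):

scaler = torch.cuda.amp.GradScaler()

# warmup
for _ in range(10):
    with torch.cuda.amp.autocast():
        out = model(data)
    scaler.scale(out.mean()).backward()

torch.cuda.synchronize()
t0 = time.time()
for _ in range(nb_iters):
    with torch.cuda.amp.autocast():
        out = model(data)
    # note: gradients accumulate across iterations here, which is fine for a timing run
    scaler.scale(out.mean()).backward()
torch.cuda.synchronize()
t1 = time.time()
print('backward amp ', (t1 - t0)/nb_iters)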

Or you could use the profiler in PyTorch.
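
For example, a minimal sketch with torch.autograd.profiler, reusing model and data from above (key_averages() aggregates the per-operator timings):

with torch.autograd.profiler.profile(use_cuda=True) as prof:
    out = model(data)
# show the 10 most expensive ops by accumulated CUDA time
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))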