Need for dividing the loss by accumulation steps when using amp

Hello all, I am trying to train a model using NVIDIA's automatic mixed precision library (apex amp), so here is what the code looks like:

import torch
from tqdm import tqdm
from apex import amp  # NVIDIA apex amp

# utils, Config and the IDX_* constants are defined elsewhere in the project

def train(train_dataset, train_dataloader, model, device, optimizer, scheduler):
    model.train()
    train_final_loss = utils.AverageMeter() # keeps a running average of the loss across batches
    tk0 = tqdm(train_dataloader, total=len(train_dataloader))

    model.zero_grad()
    for batch_index, data in enumerate(tk0):
        
        images = torch.stack(data[IDX_IMAGES]).to(device, dtype=torch.float32)
        targets = data[IDX_TARGETS]
        image_ids = data[IDX_IMAGE_IDS]

        batch_size = images.shape[0]

        boxes = [target['boxes'].to(device, dtype=torch.float32) for target in targets]
        labels = [target['labels'].to(device, dtype=torch.float32) for target in targets]
        target_res = {'cls' : labels, 'bbox' : boxes}

        outputs = model(images, target_res)
        loss = outputs['loss']
        loss = loss / Config.ACCUMULATE 
        train_final_loss.update(loss.item(), batch_size)
        if Config.USE_GRADIENT_ACCUMULATION:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()
        
        if (batch_index + 1) % Config.ACCUMULATE == 0 : 
            optimizer.step()
            optimizer.zero_grad()
        tk0.set_postfix(loss=train_final_loss.avg)

    return train_final_loss.avg

I have seen that the loss needs to be divided by the number of accumulation steps. Is that really necessary? From my understanding, gradient accumulation collects the gradients over that number of steps and only then applies them to the network, so dividing the loss by ACCUMULATION_STEPS seems wrong to me. Please let me know if I am on the right track.

I’m not sure if this question is related to amp or to the general usage of scaling the loss by the number of gradient accumulation steps.
If I’m not mistaken, you would divide the loss by the number of accumulation steps so that the accumulated gradients keep the magnitude they would have when using the “bigger batch size”.
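Concretely, for a mean-reduced loss and k equal-sized micro-batches B_1, ..., B_k that make up the “big batch” B (this notation is just for illustration), the linearity of the gradient gives

\nabla_\theta \Big( \frac{1}{k} \sum_{i=1}^{k} L(B_i) \Big) = \frac{1}{k} \sum_{i=1}^{k} \nabla_\theta L(B_i) = \nabla_\theta L(B)

so accumulating the gradients of loss/k reproduces the gradient of the big batch.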

If that doesn’t fit your use case, you could remove it.

Hi @ptrblck, sorry for mixing the concepts. Yes, this question was targeted at the scaling of the loss when using gradient accumulation. Currently this is what I am using with PyTorch 1.6’s native amp package.

import torch.cuda.amp as amp_py

scaler = amp_py.GradScaler()  # created once, before the training loop
with amp_py.autocast():
    outputs = model(images, target_res)
    loss = outputs['loss']
scaler.scale(loss).backward()
train_final_loss.update(loss.item(), batch_size)  # an AverageMeter instance that keeps a running average
if (batch_index + 1) % Config.ACCUMULATE == 0:
    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad()

and this seems to be training well for now.
Can you please elaborate on this:

If I’m not mistaken, you would divide the loss by the step size to accumulate the gradients in such a way, which would keep the gradient magnitude as it would be using the “bigger batch size”.

From what I understand, you are not immediately using the gradients to update the weights; rather, you are adding up the gradients for several steps, and after reaching that number of steps you let the optimizer update the weights. I am quite unsure how scaling the loss benefits this.

Accumulating the scaled loss (the loss divided by the number of accumulation steps) simulates training with a bigger batch size.
E.g. you could perform 10 iterations with a batch size of 1 and accumulate the gradients created by iter_loss/10, which would create the same gradients as training with a batch size of 10.
This is often used if you would otherwise run out of memory, but the training would generally benefit from a larger batch size.

Here is a simple code snippet:

import torch
import torch.nn as nn

lin = nn.Linear(1, 1)

x = torch.randn(10, 1)
y = torch.randn(10, 1)
criterion = nn.MSELoss()

# gradient accumulation approach
for i in range(x.size(0)):
    x_ = x[i:i+1]
    y_ = y[i:i+1]
    out = lin(x_)
    loss = criterion(out, y_)
    loss = loss / x.size(0)
    loss.backward()

print(lin.weight.grad, lin.bias.grad)
> tensor([[-0.9022]]) tensor([0.5286])

# "large batch" approach
lin.zero_grad()
out = lin(x)
loss = criterion(out, y)
loss.backward()

print(lin.weight.grad, lin.bias.grad)
> tensor([[-0.9022]]) tensor([0.5286])

This might not fit your use case, and you could thus remove the scaling step (without it, the accumulated gradient would be ACCUMULATE times larger than the corresponding “big batch” gradient).
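And to tie this back to your native amp code: a minimal, self-contained sketch of gradient accumulation with torch.cuda.amp could look like the following (untested for your model; the dummy nn.Linear, the ACCUMULATE constant and the random data just stand in for your detection model, Config.ACCUMULATE and dataloader, and it needs a CUDA device):

import torch
import torch.nn as nn

ACCUMULATE = 4  # stands in for Config.ACCUMULATE
model = nn.Linear(10, 1).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = nn.MSELoss()
scaler = torch.cuda.amp.GradScaler()  # create the scaler once, before the loop

optimizer.zero_grad()
for batch_index in range(16):
    x = torch.randn(1, 10, device='cuda')
    y = torch.randn(1, 1, device='cuda')
    with torch.cuda.amp.autocast():
        loss = criterion(model(x), y) / ACCUMULATE  # keep the division to mimic the bigger batch
    scaler.scale(loss).backward()  # gradients are accumulated in the .grad attributes
    if (batch_index + 1) % ACCUMULATE == 0:
        scaler.step(optimizer)  # unscales the gradients and calls optimizer.step()
        scaler.update()
        optimizer.zero_grad()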


Thanks @ptrblck for the explanation.