Gradient accumulation in an RNN with AMP

I ran into memory issues when training a large RNN, but I want to keep my effective batch size reasonable, so I wanted to try out gradient accumulation. In a network where you predict the output in one go, that seems straightforward, but in an RNN you do multiple forward passes for each input sequence. Because of that, I fear that my implementation does not work as intended. I started from @albanD’s excellent examples here, but I think they need to be modified for an RNN, because you accumulate many more gradients when you do multiple forwards per sequence.

My current implementation looks like this. It also allows for AMP in PyTorch 1.6, which seems important - everything needs to be called in the right place. Note that this is just an abstracted version; it might seem like a lot of code, but it is mostly comments.

def train(epochs):
    """Main training loop. Loops for `epoch` number of epochs. Calls `process`."""
    for epoch in range(1, epochs + 1):
        train_loss = process("train")
        valid_loss = process("valid")
        # ... check whether we improved over earlier epochs
        if lr_scheduler:
            lr_scheduler.step(valid_loss)
        
def process(do):
    """Do a single epoch run through the dataloader of the training or validation set. 
       Also takes care of optimizing the model after every `gradient_accumulation_steps` steps.
       Calls `step` for each batch where it gets the loss from."""
    if do == "train":
        model.train()
        torch.set_grad_enabled(True)
    else:
        model.eval()
        torch.set_grad_enabled(False)
    
    loss = 0.
    for batch_idx, batch in enumerate(dataloaders[do]):
        step_loss, avg_step_loss = step(batch)
        loss += avg_step_loss

        if do == "train":
            if amp:
                scaler.scale(step_loss).backward()

                if (batch_idx + 1) % gradient_accumulation_steps == 0:
                    # Unscales the gradients of optimizer's assigned params in-place
                    scaler.unscale_(optimizer)
                    # clip in-place
                    clip_grad_norm_(model.parameters(), 2.0)
                    scaler.step(optimizer)
                    scaler.update()
                    model.zero_grad()
            else:
                step_loss.backward()
                if (batch_idx + 1) % gradient_accumulation_steps == 0:
                    clip_grad_norm_(model.parameters(), 2.0)
                    optimizer.step()
                    model.zero_grad()
        
    # return the average loss over the epoch
    return loss / len(dataloaders[do])

def step(batch):
    """Processes one step (one batch) by forwarding multiple times to get a final prediction for a given sequence."""
    # do stuff... init hidden state and first input etc.
    loss = torch.tensor([0.]).to(device)

    for i in range(target_len):
        with torch.cuda.amp.autocast(enabled=amp):
            # overwrite previous decoder_hidden
            output, decoder_hidden = model(decoder_input, decoder_hidden)

            # compute loss between predicted classes (bs x classes) and correct classes for _this word_
            item_loss = criterion(output, target_tensor[i])

            # We scale the loss so that when we take an optimizer step, it accounts for
            # the mean step_loss across the accumulated batches.
            # So basically (A+B+C)/3 = A/3 + B/3 + C/3
            loss += (item_loss / gradient_accumulation_steps)

        # greedy decoding: feed the highest-scoring token back in as the next input
        topv, topi = output.topk(1)
        decoder_input = topi.detach()

    return loss, loss.item() / target_len

The above does not work as I had hoped, i.e. it still runs into out-of-memory issues very quickly. I think the reason is that step already accumulates the computation graph of every forward pass, but I am not sure.

Would it be better to call loss backward after each token prediction rather than after each step?

Based on your code it seems you are using albanD’s 3rd approach, which uses more memory and is slower than the other approaches, since it accumulates the computation graphs in each iteration and cannot free the intermediate tensors.
If you want to save memory, I would recommend trying out the 2nd approach.
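
Roughly, the 2nd approach calls backward on each batch right away so that its graph can be freed immediately, and only the optimizer step is delayed. A simplified sketch without AMP (dataloader, criterion, and accumulation_steps are placeholders):

optimizer.zero_grad()
for i, (inputs, targets) in enumerate(dataloader):
    # scale so the accumulated gradients match the mean over accumulation_steps batches
    loss = criterion(model(inputs), targets) / accumulation_steps
    loss.backward()  # the graph of this batch is freed here
    if (i + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()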

Hm, I don’t think so? I call backward before the accumulation block. But because it is a generative (word-by-word) model, I might need to move the whole if do == "train" block in process to inside the for i in range(target_len): loop in step. Is that what you mean?

My thought process being that for a target_len t, we do t forward steps, meaning that the returned loss contains a lot of history, which may have led to the increase in GPU memory.

I’m just not sure what the best practice is in a batched recurrent set-up and where to call backward when using accumulation.

I might have misunderstood the code, as you are using a loop over target_len (apparently for the RNN) as well as gradient accumulation (for multiple batches).
I’m not sure if code is missing, but note that inside step it seems you are using the same decoder_input for all steps. This might of course fit your use case, as I’m not familiar with it.

Looking at the amp-specific part: it looks correct, assuming that your gradient accumulation code works fine for the RNN use case.
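
For comparison, the gradient accumulation pattern from the amp examples notes looks roughly like this (paraphrased; iters_to_accumulate, loss_fn, and max_norm are placeholders). Your code follows the same order: scale the loss, backward, unscale_, clip, step, update.

scaler = torch.cuda.amp.GradScaler()
for i, (inputs, targets) in enumerate(dataloader):
    with torch.cuda.amp.autocast():
        output = model(inputs)
        loss = loss_fn(output, targets) / iters_to_accumulate
    scaler.scale(loss).backward()
    if (i + 1) % iters_to_accumulate == 0:
        scaler.unscale_(optimizer)  # only needed because of the clipping
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()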

Sorry for the long reply, but I’d rather reply in detail.

Some code is indeed missing (see the first comment in step()) because I did not want to bloat the code with too much boilerplate. However, concerning your comment, decoder_input is different at every iteration: it is set to the highest-scoring prediction of the previous output in the very last line of the for-loop:

decoder_input = topi.detach()

This is a typical scenario in e.g. machine translation or autoregressive language modelling where you predict a token based on the previous best output. (Some information is missing here, particularly the initial hidden state decoder_hidden, but I don’t think it’s important for the problem I am having.)

So to recapitulate: my problem is that I am running out of memory when using a large model (50k+ embedding size), even with small batch sizes, but I believe this can be remedied with a different gradient accumulation set-up. I am just not sure what the usual way to go is in an RNN set-up where you do multiple forward passes per batch and thus accumulate a lot of history.

Usually you do gradient accumulation on the batch level (accumulate n batches and only then do the optimizer step), whereas here it seems to make more sense to do this on the sequence level (accumulate n predictions in the for i in range(target_len) loop). I am just not 100% sure whether that makes sense.

Am I right in assuming that in this code snippet:

loss += (item_loss / gradient_accumulation_steps)

loss will accumulate all the history of all iterations, right? So if target_len=200, loss will contain the history of 200 forward passes? That’s why I think the backward call (and perhaps the optimizer step) should be moved inside that for i in range(target_len) loop.
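
To make my worry concrete, here is a tiny self-contained stand-in (just a linear layer instead of my actual decoder) contrasting the two set-ups:

import torch

model = torch.nn.Linear(10, 10)
x = torch.randn(4, 10)
target_len = 200

# Set-up A (what I do now): the summed loss keeps the graph of every
# forward pass alive until the single backward at the end.
loss = torch.tensor(0.)
for _ in range(target_len):
    loss = loss + model(x).sum() / target_len
loss.backward()  # only here are all 200 graphs freed

# Set-up B (what I am considering): backward per step frees each graph
# immediately, so memory does not grow with target_len.
model.zero_grad()
for _ in range(target_len):
    (model(x).sum() / target_len).backward()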

I think the right way to go is to call backward in the same loop as the model’s forward. The question then becomes how it should be normalised. Assuming that target_len is always 200, if I call optimizer.step in the same spot as I do now, that means that the gradients have been accumulated gradient_accumulation_steps * 200 times every time optimizer.step is called. If that is the case, then the loss should probably always be normalised by that amount, right? So step would look like this:

def step(batch, do):
    """Processes one step (one batch) by forwarding multiple times to get a final prediction for a given sequence."""
    # do stuff... init hidden state and first input etc.
    loss = 0.

    for i in range(target_len):
        with torch.cuda.amp.autocast(enabled=amp):
            # overwrite previous decoder_hidden
            output, decoder_hidden = model(decoder_input, decoder_hidden)
            # Not sure why I need to detach here (see below)
            decoder_hidden = decoder_hidden.detach()

            # normalise so the accumulated gradients correspond to the mean over
            # gradient_accumulation_steps batches of target_len tokens each
            step_loss = criterion(output, target_tensor[i]) / (gradient_accumulation_steps * target_len)

            # keep only a Python float for reporting; this does not hold on to the graph
            loss += step_loss.item()

        if do == "train":
            if amp:
                # for AMP see https://pytorch.org/docs/stable/notes/amp_examples.html
                scaler.scale(step_loss).backward()
            else:
                step_loss.backward()

        # greedy decoding: feed the highest-scoring token back in as the next input
        topv, topi = output.topk(1)
        decoder_input = topi.detach()

    return loss

After trying this, I got the “Trying to do a second backward pass” error, so I had to add decoder_hidden = decoder_hidden.detach(), although I am not entirely sure why.
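
For reference, here is a stripped-down toy snippet (just an RNNCell, nothing from my actual model) that reproduces the same error when the detach is removed:

import torch

rnn_cell = torch.nn.RNNCell(8, 16)
inputs = torch.randn(4, 8)
hidden = torch.zeros(4, 16)

for i in range(3):
    hidden = rnn_cell(inputs, hidden)
    loss = hidden.sum()
    loss.backward()  # frees this step's graph
    # Commenting out the next line raises "Trying to backward through the graph
    # a second time" on the 2nd iteration, because the new hidden state still
    # points into the already-freed graph of the previous step.
    hidden = hidden.detach()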