Forward pass uses a lot of memory

Hi everyone,

The memory used by my model increases x4 once I call “loss.backward()” after the forward pass. I understand that autograd stores buffers used in gradient calculations, but is a x4 increase normal? It prevents me from using a larger batch size. The memory is constant at 2.2 GB at the beginning of every batch.

Here is the code I use.

for i, batch in enumerate(data_iter):
        src = batch.text.transpose(0, 1).to(args.device)
        tgt =, 1).to(args.device)

        output = model(src)
        loss_t = loss_fn(output, tgt)


        total_loss += loss_t.item()

Thank you.


A large memory increase during the backward pass is expected: at that point you need to hold the forward buffers, the gradient tensors, and the intermediate gradients being computed, all at the same time.
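If you want to see roughly where the memory goes, here is a minimal sketch that counts bytes for the parameters, the buffers autograd saves during the forward pass, and the gradients left after the backward pass. The helper names (`tensor_bytes`, `memory_breakdown`) and the small toy model are made up for illustration and are not from your code. It also does not capture the temporary intermediate gradients, which only exist while the backward pass is running.

```python
import torch
import torch.nn as nn


def tensor_bytes(t):
    # Size of a tensor's data in bytes.
    return t.numel() * t.element_size()


def memory_breakdown(model, loss_fn, src, tgt):
    """Rough byte counts for parameters, saved forward buffers, and gradients.

    Hypothetical helper for illustration only.
    """
    # Parameters are also saved for backward (often as views); skip them so
    # they are not counted twice.
    param_ptrs = {p.data_ptr() for p in model.parameters()}
    saved = {}  # keyed by data pointer so a buffer saved twice counts once

    def pack(t):
        if t.data_ptr() not in param_ptrs:
            saved[t.data_ptr()] = tensor_bytes(t)
        return t  # keep the tensor itself, so autograd behaves as usual

    def unpack(t):
        return t

    model.zero_grad(set_to_none=True)
    # Every tensor autograd keeps for the backward pass goes through `pack`.
    with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
        loss = loss_fn(model(src), tgt)
    loss.backward()

    params = sum(tensor_bytes(p) for p in model.parameters())
    grads = sum(
        tensor_bytes(p.grad) for p in model.parameters() if p.grad is not None
    )
    return {
        "params": params,
        "saved_for_backward": sum(saved.values()),
        "grads": grads,
    }


# Toy model and random data, assumed purely for the example.
model = nn.Sequential(nn.Linear(512, 2048), nn.ReLU(), nn.Linear(2048, 10))
src = torch.randn(64, 512)
tgt = torch.randint(0, 10, (64,))

stats = memory_breakdown(model, nn.CrossEntropyLoss(), src, tgt)
for name, size in stats.items():
    print(f"{name}: {size / 1024**2:.2f} MiB")
```

The gradients take as much space as the parameters, and the saved buffers grow with the batch size, which is why the increase shows up around the backward pass. On a GPU, calling `torch.cuda.reset_peak_memory_stats()` before a batch and reading `torch.cuda.max_memory_allocated()` after it should give the actual peak, including the temporary gradients this sketch leaves out.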