The forward pass of my model roughly quadruples memory usage by the time I call `loss.backward()`. I understand that the forward pass stores intermediate buffers needed for the gradient computation, but is a 4x increase normal? It prevents me from using a larger batch size. Memory is back to a constant 2.2 GB at the beginning of every batch.
Here is the code I use:
```python
for i, batch in enumerate(data_iter):
    optimizer.zero_grad()
    src = batch.text.transpose(0, 1).to(args.device)
    tgt = batch.target.transpose(0, 1).to(args.device)
    output = model(src)
    loss_t = loss_fn(output, tgt)
    loss_t.backward()
    optimizer.step()
    total_loss += loss_t.item()
```
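To pin down where the growth happens, here is a minimal instrumented sketch of the same loop. It assumes `args.device` is a CUDA device and that `model`, `optimizer`, `loss_fn`, and `data_iter` are defined as above; `torch.cuda.memory_allocated()` reports the bytes currently held by tensors on the device, and the print points are only illustrative:

```python
import torch

def mem_mb():
    # Bytes currently allocated by tensors on the default CUDA device, in MB.
    return torch.cuda.memory_allocated() / 1024**2

for i, batch in enumerate(data_iter):
    optimizer.zero_grad()
    src = batch.text.transpose(0, 1).to(args.device)
    tgt = batch.target.transpose(0, 1).to(args.device)
    print(f"start of batch: {mem_mb():.0f} MB")

    output = model(src)
    print(f"after forward:  {mem_mb():.0f} MB")  # activations kept for backward

    loss_t = loss_fn(output, tgt)
    loss_t.backward()
    print(f"after backward: {mem_mb():.0f} MB")  # activations freed, .grad buffers allocated

    optimizer.step()
    print(f"after step:     {mem_mb():.0f} MB")  # optimizer state (e.g. Adam moments)

    total_loss += loss_t.item()
```

Comparing the "after forward" figure against the 2.2 GB baseline should show whether the 4x peak comes from stored activations or from the gradient and optimizer-state buffers that appear after `backward()` and `step()`.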