Hi everyone,
My model's memory usage roughly quadruples (x4) after I call `loss.backward()`. I understand that autograd stores intermediate buffers needed for gradient computation, but is a x4 increase normal? It prevents me from using a larger batch size. Memory is a constant 2.2 GB at the beginning of every batch.
Here is the training loop I use:
```python
for i, batch in enumerate(data_iter):
    optimizer.zero_grad()
    # Transpose to (batch, seq_len) and move to the target device
    src = batch.text.transpose(0, 1).to(args.device)
    tgt = batch.target.transpose(0, 1).to(args.device)
    output = model(src)
    loss_t = loss_fn(output, tgt)
    loss_t.backward()
    optimizer.step()
    # .item() extracts a Python float, so the graph is not kept alive
    total_loss += loss_t.item()
```
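For context, here is a minimal, self-contained sketch of why extra memory appears between forward and backward; `x` and `w` are toy tensors just for illustration, not anything from my model:

```python
import torch

# Hypothetical toy tensors, standing in for model inputs/weights
x = torch.randn(4, 3, requires_grad=True)
w = torch.randn(3, 3, requires_grad=True)

h = (x @ w).relu()  # intermediate activation
y = h.sum()

# Until backward() runs, autograd keeps forward activations
# (e.g. the input to relu) alive so it can compute gradients.
print(y.grad_fn is not None)  # True: a graph is attached to y

y.backward()  # gradients computed; the graph is freed by default
print(w.grad.shape)  # torch.Size([3, 3])
```

On top of activations, `.grad` buffers double the parameter memory, and stateful optimizers like Adam add two more buffers per parameter after the first `step()`, so a multi-fold increase over the bare model is plausible.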
Thank you.