It seems you are storing the complete computation graph in this line of code:
epoch_loss += l1_loss
Since l1_loss is still attached to the computation graph, each accumulation keeps that iteration's graph alive, which will steadily increase memory usage over the epoch.
If you only want to use epoch_loss for printing/debugging purposes (i.e. you don't intend to call epoch_loss.backward() later), you should detach l1_loss before accumulating it via:
epoch_loss += l1_loss.detach() # or .item()
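A minimal sketch of the pattern in a training loop (the model, data, and optimizer here are hypothetical placeholders, not taken from your code):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(10, 1)
criterion = nn.L1Loss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

x = torch.randn(32, 10)
y = torch.randn(32, 1)

epoch_loss = 0.0
for _ in range(5):
    optimizer.zero_grad()
    l1_loss = criterion(model(x), y)
    l1_loss.backward()
    optimizer.step()
    # .item() returns a plain Python float, so no graph is kept around;
    # l1_loss.detach() would instead give a graph-free tensor.
    epoch_loss += l1_loss.item()

print(epoch_loss)  # plain float, safe to log or print
```

Note the difference between the two options: .detach() still returns a tensor (useful if you want to keep computations on the GPU), while .item() synchronizes and converts to a Python number.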