[Solved] Memory Usage Increasing Constantly

While training an autoencoder, memory usage increases steadily over time (eventually consuming the full ~64 GB available). Initially I thought it was just the loss function, but I get the same behavior with both BCELoss and MSELoss.

I’ve read the FAQ about memory increasing and ensured that I’m not unintentionally keeping gradients in memory.

Attempting to split the data into mini-batches (“chunks” in the code example) does not change the behavior at all.

    opt = torch.optim.Adam(net.parameters())
    loss_fn = nn.BCELoss()
    # for reducing RAM usage use subsets of training data
    subsets = 8
    chunks = train_ingr.split(len(indexes)//subsets)
    for i in range(epochs):
        loss = loss_fn(net(chunks[i%subsets]), chunks[i%subsets])
        loss.backward()
        train_loss_hist.append(loss.detach().numpy())
        opt.step()
        opt.zero_grad()

I don’t think I’m doing anything particularly wrong. The network being trained is a simple fully connected (dense) autoencoder with nonlinear activations. Everything is on CPU, so there shouldn’t be accidental copies created by moving data between devices.
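
Something of this general shape (the class name, layer sizes, and activations below are placeholders to illustrate “dense autoencoder”, not my exact model):

    import torch
    import torch.nn as nn

    class DenseAutoencoder(nn.Module):
        def __init__(self, n_features, n_hidden=128, n_latent=32):
            super().__init__()
            # encoder/decoder are plain fully connected stacks
            self.encoder = nn.Sequential(
                nn.Linear(n_features, n_hidden), nn.Mish(),
                nn.Linear(n_hidden, n_latent), nn.Mish(),
            )
            # Sigmoid keeps reconstructions in [0, 1], as BCELoss requires
            self.decoder = nn.Sequential(
                nn.Linear(n_latent, n_hidden), nn.Mish(),
                nn.Linear(n_hidden, n_features), nn.Sigmoid(),
            )

        def forward(self, x):
            return self.decoder(self.encoder(x))

    net = DenseAutoencoder(n_features=train_ingr.shape[1])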

Please let me know if I’ve missed something important.

Edit: adding memory profiling results:

-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg       CPU Mem  Self CPU Mem    # of Calls
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
                                    aten::empty_strided         0.10%      41.003ms         0.10%      41.003ms       1.544us      48.96 Gb      48.96 Gb         26562
                                            aten::addmm         6.94%        2.900s        10.67%        4.457s       1.431ms      26.50 Gb      26.50 Gb          3114
                                         aten::_softmax        11.82%        4.940s        11.82%        4.940s       9.518ms      25.85 Gb      25.85 Gb           519
                           aten::_softmax_backward_data         9.73%        4.064s         9.73%        4.064s       8.127ms      24.02 Gb      24.02 Gb           500
                                       aten::empty_like         0.02%       8.228ms         0.08%      35.009ms      22.689us      49.73 Gb       2.55 Gb          1543
                                               aten::mm        13.40%        5.600s        13.40%        5.600s       1.018ms     821.60 Mb     821.58 Mb          5500
                                             aten::mish         3.77%        1.574s         3.77%        1.574s     758.274us     652.97 Mb     652.97 Mb          2076
                                    aten::mish_backward         4.81%        2.009s         4.81%        2.012s       1.006ms     605.75 Mb     605.75 Mb          2000
                                           aten::linear         0.12%      52.147ms        10.79%        4.509s       1.448ms      26.50 Gb     289.33 Mb          3114
                                             aten::sqrt         0.10%      41.426ms         0.10%      41.426ms       6.904us     205.93 Mb     205.93 Mb          6000
                                              aten::div         0.11%      44.089ms         0.20%      84.888ms      14.148us     205.93 Mb     205.46 Mb          6000
                                             aten::tanh         0.03%      10.577ms         0.03%      10.577ms      20.380us      16.32 Mb      16.32 Mb           519
                                    aten::tanh_backward         0.02%       7.644ms         0.02%       7.644ms      15.288us      14.93 Mb      14.93 Mb           500
                                          MishBackward0         0.23%      94.971ms         4.83%        2.018s       1.009ms     590.21 Mb      12.26 Mb          2000
                                              aten::sum         5.38%        2.246s         5.42%        2.264s     643.471us       6.34 Mb       6.31 Mb          3519
                                               aten::to         0.10%      41.906ms         0.39%     163.319ms       6.392us     529.35 Kb     441.54 Kb         25550
                                            aten::zero_         0.00%      55.000us         0.00%     198.000us       8.250us     202.69 Kb     202.69 Kb            24
                                         aten::_to_copy         0.25%     102.501ms         0.32%     133.250ms       5.326us     112.85 Kb      37.05 Kb         25019
                                            aten::empty         0.01%       4.462ms         0.01%       4.462ms       1.763us      31.47 Kb      31.47 Kb          2531
                                     aten::resolve_conj         0.00%     231.000us         0.00%     231.000us       0.011us      15.75 Kb      15.75 Kb         20247
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
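
(The table above is from the PyTorch profiler with memory profiling turned on; roughly something like this, with the training loop elided:)

    from torch.profiler import profile, ProfilerActivity

    with profile(activities=[ProfilerActivity.CPU], profile_memory=True) as prof:
        # the training loop from above runs here
        ...

    print(prof.key_averages().table(sort_by="self_cpu_memory_usage", row_limit=20))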

Tried again with different loss functions and optimizers.

Optimizers
SGD: No difference, memory increases in the same way
Adadelta: No difference, memory increases in the same way

Loss Functions
A custom implementation of Dice Loss: no problem; there’s barely any increase in memory over time.
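
(For anyone unfamiliar with it, a generic Dice loss has this general shape; this is the textbook formulation, not necessarily my exact code:)

    def dice_loss(pred, target, eps=1e-7):
        # pred and target have the same shape, with values in [0, 1]
        intersection = (pred * target).sum()
        union = pred.sum() + target.sum()
        return 1.0 - (2.0 * intersection + eps) / (union + eps)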

It looks like the other loss functions (BCELoss and MSELoss) are related to the memory issue. Unfortunately, BCELoss is the only one that works well for my application so far.

In case something was being stored internally by the loss function, I tried creating a new instance of nn.BCELoss each epoch, but this had no effect on the memory usage (it still increases constantly).
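
Concretely, that test just moved the construction inside the loop, roughly:

    for i in range(epochs):
        loss_fn = nn.BCELoss()  # fresh instance every epoch; made no difference
        loss = loss_fn(net(chunks[i % subsets]), chunks[i % subsets])
        # ... rest of the loop unchanged ...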

I’m guessing that the issue is a combination of data being generated by the loss function and stored by the optimizer. Does that sound right? Is there a good way to address the issue?

(BTW, I’ve tried using half precision, but there’s a known bug that makes it tremendously slow on CPU.)

On a whim I changed loss.detach().numpy() to loss.item(), and the memory issue disappeared.
Could someone explain whether this makes sense? I thought the two would be roughly equivalent for my setup.
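
For reference, the only change was the logging line in the training loop:

    # before: memory grew every epoch
    train_loss_hist.append(loss.detach().numpy())
    # after: memory stays flat
    train_loss_hist.append(loss.item())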

What’s even stranger is that this behavior was not consistent across loss functions (a bug?).