RAM usage increases with each data size change

I am iterating over data whose size I change with each epoch. For example, in the first epoch the images are 224x224, and in the next my batch transforms resize them to 200x200 (or some other size in that range). I am using AdamW as my optimizer and decaying the learning rate on plateau with a learning rate scheduler.
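
For context, the setup looks roughly like this. EpochResize is a simplified, hypothetical stand-in for my actual batch transform (the class name, the min/max arguments, and the fixed_size flag exist only in this sketch), and the Conv2d at the end is just a stand-in for the real network; the relevant part is that epoch_start() picks one target size per epoch, which is the hook the profiled one_epoch() below calls:

import random

import torch
import torchvision.transforms.functional as TF

class EpochResize:
    """Pick one target (H, W) per epoch and resize every minibatch to it."""

    def __init__(self, min_size=200, max_size=224, fixed_size=None):
        self.min_size = min_size
        self.max_size = max_size
        self.fixed_size = fixed_size          # set this for the static-size control run
        self.size = fixed_size or (max_size, max_size)

    def epoch_start(self):
        # Called once at the top of each epoch, not per minibatch.
        if self.fixed_size is None:
            self.size = (random.randint(self.min_size, self.max_size),
                         random.randint(self.min_size, self.max_size))

    def __call__(self, images):
        # Applied to the whole minibatch inside one_minibatch().
        return TF.resize(images, list(self.size))

net = torch.nn.Conv2d(3, 8, 3)                # stand-in for the real model
optimizer = torch.optim.AdamW(net.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)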

I have noticed that when I change the data size during training, my (system) RAM usage grows every time the size changes. To check whether this was an error in my batch transform, I used the exact same batch transform but had it always pick the same size; with that change, RAM usage stays flat indefinitely, which is what I would have expected. Whether or not I use a learning rate scheduler also makes no difference.
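
In terms of that sketch, the two profiled runs below differ only in how the transform is constructed:

# Varying size: epoch_start() draws a new (H, W) every epoch -> RAM keeps growing.
aug_func = EpochResize(min_size=200, max_size=224)

# Static size: identical code path, but the size never changes -> RAM stays flat.
aug_func = EpochResize(fixed_size=(224, 224))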

Previously I was changing the image size per minibatch, and I think it's understandable that doing that during an epoch could make the computational graph grow. But now that I only change the image size once per epoch, I am surprised that the memory usage keeps growing.

Is this growth in system RAM use expected?

Memory profiles from python3 -m memory_profiler are below. The growth in one_minibatch() is a function of my code, but the growth in loss.backward() and, to a lesser extent, optimizer.step() seems out of my control. I would expect the computational graph to grow if the size varied within one epoch, but I'm not sure why it is growing here, given that the size is fixed within each epoch (although it varies between epochs).
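
For reference, the tables were generated along these lines (a minimal sketch; train.py and the driver loop are placeholders for my actual training script):

# train.py -- run with:  python3 -m memory_profiler train.py
# memory_profiler prints a Mem usage / Increment / Occurrences table for each
# line of every @profile-decorated function, which is where the tables below
# come from.
from memory_profiler import profile

@profile
def one_epoch(net, optimizer, aug_func):
    ...  # the training loop shown line by line in the profiles below

if __name__ == "__main__":
    for epoch in range(10):              # 10 epochs -> the "Occurrences = 10" rows
        one_epoch(None, None, None)      # placeholder arguments for this sketch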

Changing the size with every epoch:

Line #    Mem usage    Increment  Occurrences   Line Contents
=============================================================
   160 2322.016 MiB  559.344 MiB          10   @profile
   161                                         def one_epoch(net, optimizer, aug_func):
   162 2322.016 MiB    0.000 MiB          10       if aug_func:
   163 2322.016 MiB    0.000 MiB          10           aug_func.epoch_start()
   164                                                 
   165 2322.016 MiB    0.000 MiB          10       running_loss = 0.0
   166 2322.016 MiB    0.000 MiB          10       n_running_loss = 0
   167 2482.203 MiB  125.359 MiB         330       for i, data in enumerate(train_dataloader):
   168 2411.297 MiB    0.016 MiB         320           n_running_loss += len(data[0])
   169                                         
   170                                                 # Zero the parameter gradients
   171 2411.297 MiB    0.000 MiB         320           optimizer.zero_grad()
   172                                         
   173 2445.062 MiB  751.953 MiB         320           loss = one_minibatch(net, data, aug_func)
   174                                                 
   175                                                 # Backward pass
   176 2482.188 MiB  847.141 MiB         320           loss.backward()
   177                                         
   178                                                 # Update optimizer
   179 2482.203 MiB  198.281 MiB         320           optimizer.step()
   180                                         
   181                                                 # Accumulate statistics
   182 2482.203 MiB    0.109 MiB         320           running_loss += loss.item() * len(data[0])
   183                                                 
   184 2482.203 MiB    0.000 MiB          10       running_loss /= n_running_loss
   185                                         
   186                                             # Re-enable batchnorm and dropout
   187 2482.203 MiB    0.000 MiB          10       net.train()

Keeping the size static:

Line #    Mem usage    Increment  Occurrences   Line Contents
=============================================================
   160 1016.766 MiB  553.469 MiB          10   @profile
   161                                         def one_epoch(net, optimizer, aug_func):
   162 1016.766 MiB    0.000 MiB          10       if aug_func:
   163 1016.766 MiB    0.000 MiB          10           aug_func.epoch_start()
   164                                                 
   165 1016.766 MiB    0.000 MiB          10       running_loss = 0.0
   166 1016.766 MiB    0.000 MiB          10       n_running_loss = 0
   167 1016.828 MiB   99.656 MiB         330       for i, data in enumerate(train_dataloader):
   168 1016.828 MiB    0.016 MiB         320           n_running_loss += len(data[0])
   169                                         
   170                                                 # Zero the parameter gradients
   171 1016.828 MiB    0.000 MiB         320           optimizer.zero_grad()
   172                                         
   173 1016.828 MiB  107.375 MiB         320           loss = one_minibatch(net, data, aug_func)
   174                                                 
   175                                                 # Backward pass
   176 1016.828 MiB  100.703 MiB         320           loss.backward()
   177                                         
   178                                                 # Update optimizer
   179 1016.828 MiB  155.547 MiB         320           optimizer.step()
   180                                         
   181                                                 # Accumulate statistics
   182 1016.828 MiB    0.062 MiB         320           running_loss += loss.item() * len(data[0])
   183                                                 
   184 1016.828 MiB    0.000 MiB          10       running_loss /= n_running_loss
   185                                         
   186                                             # Re-enable batchnorm and dropout
   187 1016.828 MiB    0.000 MiB          10       net.train()

It seems that if I limit the variety of different data shapes (e.g., instead of letting x and y each range freely over 200-224, which allows 25 * 25 = 625 distinct shapes), my batchnorm data don't become massive and the optimizer state also stays reasonable.

So now, instead of the full 200-224 range, I use a limited list of x and y sizes that covers enough variety for my needs: [200, 212, 216, 224]. That list retains most of the dynamic behavior I care about while cutting the number of distinct (x, y) shape combinations from 625 down to 4^2 = 16.
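
In terms of the earlier sketch, the only change is where the per-epoch size comes from; the idea is that whatever per-shape state is being accumulated can now only be created 16 times, so memory should level off after a few epochs instead of growing for the whole run:

import random

SIZES = [200, 212, 216, 224]              # 4 x-values * 4 y-values = 16 shapes

def pick_epoch_size():
    # Called once per epoch (e.g. from epoch_start()): choose each dimension
    # independently from the small fixed list instead of using
    # random.randint(200, 224), which allows 25 * 25 = 625 combinations.
    return (random.choice(SIZES), random.choice(SIZES))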