Memory bloat in gradient descent model

I’m running into memory bloat issues when backpropagating through this loop-heavy simulation. Even though the code looks simple, the backward pass consumes a lot of memory. I’d like to understand whether there is any way to trim the computation graph, or whether, because I am integrating an ODE where the current state depends on the previous state, this is already the most memory-efficient implementation. For context, this model currently uses 14 GB. When I estimate by hand how much it should need, accounting for the autograd nodes, it shouldn’t even take 1 GB.

Any insight into whether the loop is causing unnecessary graph storage, or how to refactor it to reduce memory overhead, would be much appreciated!

import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import gc
from torch.cuda.amp import autocast


class LIF_ODE(nn.Module):
    def __init__(self):
        super().__init__()
        self.dt = 0.0001        # integration time step (s)
        self.tau = 0.02         # membrane time constant (s)
        self.V_th = 1.0         # spike threshold
        self.V_reset = 0.0      # reset potential after a spike
        self.V_rest = 0.0       # resting potential
        self.input_current = nn.Parameter(torch.tensor(1.5, dtype=torch.float32))  # learnable input current

    def forward(self):
        T = 3500000  # number of Euler integration steps
        V = torch.zeros(T, dtype=torch.float32)
        spikes = torch.zeros(T, dtype=torch.float32)
        v_t = self.V_rest

        for t in range(1, T):
            # forward-Euler update of the leaky integrate-and-fire ODE
            dv = (-v_t + self.input_current) * (self.dt / self.tau)
            v_t = v_t + dv

            # hard threshold and reset (non-differentiable)
            if v_t >= self.V_th:
                spikes[t] = 1.0
                v_t = self.V_reset

            V[t] = v_t

        return spikes

def main():
    model = LIF_ODE()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01, betas=(0.0, 0.999))
    num_epochs = 1
    target_spikes = torch.tensor(50.0, dtype=torch.float32)

    for epoch in range(num_epochs):
        optimizer.zero_grad()

        with autocast():
            output = model()

        fr = output.sum() / 10 / 3
        loss = (fr - target_spikes) ** 2
        loss.backward()
        optimizer.step()

        gc.collect()
        print(f"Epoch {epoch}: Loss = {loss.item()}", flush=True)

if __name__ == "__main__":
    main()

Hey,
This tool might be helpful for this type of investigation, if you haven’t seen it already.
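
If the tool you mean is the CUDA memory snapshot recorder and its web viewer (my assumption; it needs the tensors to live on the GPU), a minimal usage sketch would look like this:

import torch

# Hedged sketch, assuming the tool in question is the CUDA memory snapshot recorder.
torch.cuda.memory._record_memory_history(max_entries=100000)

# ... run the forward/backward pass you want to inspect ...

# Dump a snapshot, open it in the viewer at https://pytorch.org/memory_viz,
# then stop recording.
torch.cuda.memory._dump_snapshot("snapshot.pickle")
torch.cuda.memory._record_memory_history(enabled=None)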

Thank you,

Perhaps I can try moving everything over to the GPU and then visualizing things. Currently I am storing my model in RAM and just running everything on the CPU.

Hi Isaac!

I don’t have an opinion about whether this memory usage is to be expected or not (or whether it can be reduced), but I do have a couple of comments.

With each iteration of your loop – and there are 3.5 million of them – you grow your computation graph
by a small number of nodes. Taking the memory usage you quote (presumably not all of it used
by your growing computation graph), 14 GB / 3.5 million iterations works out to about 4000 bytes per
loop iteration. That doesn’t seem outlandish to me (and I don’t have an opinion about whether it can
be reduced significantly).

The core use cases that pytorch is designed for don’t involve networks that are millions of
layers deep; they typically involve pumping batches of larger objects (e.g., images) through
a much smaller number of layers. So in the typical case a computation-graph node needs to store
a rather large tensor for backpropagation, but there are relatively few such nodes – rather than
your case, where each of a very large number of nodes stores only a single scalar.

As such, I could well believe that requiring a few hundred or a few thousand bytes of
bookkeeping / overhead per node could be considered an acceptable design choice.

As written, the return value of your forward() function, spikes, doesn’t carry
requires_grad = True and is not part of any computation graph. Furthermore, because
spikes depends on v_t >= self.V_th, it is not usefully differentiable, so (regardless of
how much memory your unused computation graph consumes) spikes and your forward()
function don’t give you anything you can backpropagate through.
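
To make this concrete, here is a small illustrative sketch (toy sizes and names, not your actual code) showing that assigning a constant leaves a tensor outside the graph, while assigning v_t pulls a tensor in:

import torch

I = torch.tensor(1.5, requires_grad=True)
v_t = torch.tensor(0.0) + (-0.0 + I) * 0.005   # v_t depends on I and carries a grad_fn
spikes = torch.zeros(3)
V = torch.zeros(3)

spikes[1] = 1.0     # assigning a constant: spikes stays outside the graph
V[1] = v_t          # assigning v_t: V joins the graph

print(spikes.grad_fn)   # None -- nothing to backpropagate through
print(V.grad_fn)        # CopySlices -- the kind of node that accumulates in your loop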

Right now, each entry in spikes is a (not-usefully-differentiable) step function in v_t.
I’m not saying it would work (I don’t understand your use case), but you could potentially
use sigmoid() as a (usefully) differentiable approximation to your step function. (Of
course, this wouldn’t address the issue with your 3.5-million-iteration computation graph.)
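
For illustration, a minimal sketch of that idea (the sharpness beta is a hypothetical choice of mine):

import torch

V_th = 1.0
beta = 50.0                                   # sharpness of the surrogate (hypothetical choice)
v_t = torch.tensor(1.02, requires_grad=True)

hard = (v_t >= V_th).float()                  # step function: zero gradient almost everywhere
soft = torch.sigmoid(beta * (v_t - V_th))     # smooth approximation with a usable gradient

soft.backward()
print(hard.item(), soft.item(), v_t.grad.item())   # the surrogate yields a nonzero gradient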

Best.

K. Frank

Hi KFrank,

Thank you so much for your response. I tried reimplementing this and computing the gradients manually, which seems to have gotten rid of the memory bloat: it reduced the model from 14 GB to under 100 MB. The functions I am using in my larger model are straightforward to differentiate, so it shouldn’t be too bad to convert out of PyTorch into a manual implementation.
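
In case it helps anyone later, the rough idea looks like this (a simplified sketch, not my actual larger model; the sigmoid surrogate and its sharpness beta are stand-ins for the differentiable pieces): propagate the sensitivity dv/dI alongside v by hand, so no autograd graph is ever built.

import math

dt, tau, V_th, V_reset, beta = 1e-4, 0.02, 1.0, 0.0, 50.0
I, T = 1.5, 3500000

v, dv_dI = 0.0, 0.0      # membrane potential and its sensitivity d(v)/d(I)
fr, dfr_dI = 0.0, 0.0    # accumulated surrogate spike count and its derivative

for t in range(1, T):
    v = v * (1 - dt / tau) + I * (dt / tau)        # same Euler update as forward()
    dv_dI = dv_dI * (1 - dt / tau) + dt / tau      # chain rule, written out by hand

    s = 1.0 / (1.0 + math.exp(-beta * (v - V_th)))  # smooth spike surrogate
    fr += s
    dfr_dI += beta * s * (1 - s) * dv_dI            # derivative of the surrogate

    if v >= V_th:                                   # hard reset; the sensitivity resets too
        v, dv_dI = V_reset, 0.0

target = 50.0
fr, dfr_dI = fr / 10 / 3, dfr_dI / 10 / 3           # same scaling as in main()
loss = (fr - target) ** 2
grad_I = 2 * (fr - target) * dfr_dI                 # d(loss)/d(input_current), by hand
print(loss, grad_I)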

Thanks!