For-loop backward pass memory leak for batch size > 1

Hi, I defined a model (albeit somewhat naively) with many nested for loops. The forward pass works fine for all batch sizes, but for batch sizes > 1 the backward pass blows up the memory footprint and the Python kernel crashes the system (I only have 8 GB on this particular machine). To give you an idea of the discrepancy: batch size 1 takes up less than 2% of total system memory (part of the network is not shown), but batch size 2 takes over 95% before it crashes. Am I making a mistake somewhere? I'm pretty sure the offending code is below:

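    def prim_to_uhat(self, x):
        # x: (batch, 32, 6, 6, in_dim); each self.wij[i][j][k][l] maps in_dim -> 16.
        # new_zeros matches x's dtype/device; every entry is overwritten below.
        u_hat = x.new_zeros(x.size(0), 32, 6, 6, 10, 16)
        # Note: q is never used, so the whole body runs x.size(0) times,
        # recomputing every slice for the full batch on each pass.
        for q in range(x.size(0)):
            for i in range(10):
                for j in range(32):
                    for k in range(6):
                        for l in range(6):
                            u_hat[:, j, k, l, i, :] = self.wij[i][j][k][l](x[:, j, k, l, :])
        return u_hat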

Use vectorized operations rather than nested for loops, please.
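A minimal sketch of what a vectorized version could look like, assuming each self.wij[i][j][k][l] is a linear layer (in_dim -> 16, with bias) and x has shape (batch, 32, 6, 6, in_dim). The class name PrimToUhat, in_dim=8 and the weight initialization are hypothetical. All 10 × 32 × 6 × 6 weight matrices live in one parameter, and a single torch.einsum applies each matrix to its own position, instead of recording one small op plus one in-place slice copy per position.

```python
import torch
import torch.nn as nn


class PrimToUhat(nn.Module):
    """Hypothetical vectorized stand-in for prim_to_uhat.

    Assumes x has shape (batch, 32, 6, 6, in_dim) and that every
    per-position map is a linear layer in_dim -> out_dim with a bias.
    """

    def __init__(self, in_dim=8, out_dim=16, n_out_caps=10):
        super().__init__()
        # One (out_dim, in_dim) matrix per (j, k, l, i), stored in one tensor.
        # The 0.01 * randn initialization is a placeholder, not a recommendation.
        self.weight = nn.Parameter(
            0.01 * torch.randn(32, 6, 6, n_out_caps, out_dim, in_dim)
        )
        self.bias = nn.Parameter(torch.zeros(32, 6, 6, n_out_caps, out_dim))

    def forward(self, x):
        # For every (j, k, l, i): u_hat[b, j, k, l, i] = W[j, k, l, i] @ x[b, j, k, l] + bias.
        # Output shape: (batch, 32, 6, 6, n_out_caps, out_dim), same layout as the loop.
        u_hat = torch.einsum("bjkld,jklied->bjklie", x, self.weight)
        return u_hat + self.bias
```

Storing the weights as one parameter changes how they are held compared with nested lists of modules; existing wij weights would need to be copied in (e.g. with torch.stack) to keep them.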

I agree it should be written that way, but why does memory explode disproportionately with batch sizes > 1 when using nested for loops? Is it a problem with my code or a limitation of PyTorch?

Without seeing the complete network I can’t tell. However, a number of things can differ between batch size 1 and batch size 2. E.g., tensors that are contiguous with batch size 1 might not be with a larger batch size.

So you’re saying that perhaps the tensor is too large to fit into a single memory “block” (I don’t know much about how memory is allocated), and therefore takes up more memory than expected?

Any chance you could take a look at the network? I have it linked here:

What I mean by “being contiguous” is that the values in the tensor lie in one contiguous chunk of memory in the order of the dimensions (so the first dimension has the largest stride, the last dimension has stride 1, etc.). E.g., torch.randn(1,5).t() is still contiguous, but torch.randn(2,5).t() is not.
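A quick way to check that example is to print shapes, strides and is_contiguous(). A transpose only swaps the strides without moving any data. The contiguity check ignores the stride of a size-1 dimension, which is why the transposed (1, 5) tensor still counts as contiguous, while a freshly allocated torch.randn(5, 2) is laid out row-major and is contiguous.

```python
import torch

a = torch.randn(1, 5).t()  # transpose of a (1, 5) tensor -> shape (5, 1)
b = torch.randn(2, 5).t()  # transpose of a (2, 5) tensor -> shape (5, 2)
c = torch.randn(5, 2)      # freshly allocated, row-major (5, 2)

# t() swaps the strides but leaves the data where it was.
print(a.shape, a.stride(), a.is_contiguous())  # torch.Size([5, 1]) (1, 5) True
print(b.shape, b.stride(), b.is_contiguous())  # torch.Size([5, 2]) (1, 5) False
print(c.shape, c.stride(), c.is_contiguous())  # torch.Size([5, 2]) (2, 1) True

# .contiguous() copies b's values into a new row-major buffer.
d = b.contiguous()
print(d.stride(), d.is_contiguous())  # (2, 1) True
```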

I suggest just writing it with vectorized operations. Your current code will be very slow anyway.

Thanks, I will definitely rewrite it with vectorized operations; I just wanted to know why it exploded so that I don’t make the same mistake in the future.

So are you saying that torch.randn(2,5).t() is not contiguous, but torch.randn(5,2) would be, because of the transpose? Sorry, is there a reference on the nuances of autograd and contiguity? Also, for what it’s worth, I printed x.is_contiguous() for the input as it went through the network, and it says True at every step.

I’m not saying that the issue is contiguity. I’m just giving an example of how things can differ between batch sizes :)

Ah, OK. Thank you, Simon, for your help; I really do appreciate it! If I wanted to figure out exactly why the batch size breaks the backward pass, how would I debug it? I can’t just put print statements in the forward pass, right? Would it involve hooks?

You can, but I’d suggest just writing better code. DL frameworks, including PyTorch, aren’t designed for deeply nested Python for loops. You will probably need to change it anyway.
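For the debugging question, here is a minimal sketch of the hook-based approach, assuming PyTorch 1.10 or newer (for torch.autograd.graph.saved_tensors_hooks). It uses a hypothetical, scaled-down copy of the posted loop (smaller sizes and a flat ModuleList standing in for self.wij) next to a vectorized equivalent. For each, it counts the autograd nodes reachable from the output's grad_fn and uses a pack hook to tally the bytes of every tensor autograd saves for backward. Because the posted loop never uses q, its body reruns x.size(0) times over the full batch, so with the loop version both numbers grow with the batch size, while the vectorized graph keeps the same node count.

```python
import torch
import torch.nn as nn
from torch.autograd.graph import saved_tensors_hooks

# Hypothetical, scaled-down sizes so the loop version runs quickly.
J, K, L, I, D_IN, D_OUT = 4, 2, 2, 3, 8, 16

# One nn.Linear per (i, j, k, l), flattened; stands in for self.wij.
wij = nn.ModuleList(nn.Linear(D_IN, D_OUT) for _ in range(I * J * K * L))


def loop_version(x):
    # Mirrors the posted code, including the outer q loop that never uses q.
    u_hat = x.new_zeros(x.size(0), J, K, L, I, D_OUT)
    for q in range(x.size(0)):
        for i in range(I):
            for j in range(J):
                for k in range(K):
                    for l in range(L):
                        layer = wij[((i * J + j) * K + k) * L + l]
                        u_hat[:, j, k, l, i, :] = layer(x[:, j, k, l, :])
    return u_hat


def vectorized_version(x):
    # Same linear maps, applied with one broadcasted multiply and a sum.
    W = torch.stack([m.weight for m in wij]).view(I, J, K, L, D_OUT, D_IN)
    b = torch.stack([m.bias for m in wij]).view(I, J, K, L, D_OUT)
    W = W.permute(1, 2, 3, 0, 4, 5)  # (J, K, L, I, D_OUT, D_IN)
    x = x.unsqueeze(4).unsqueeze(5)  # (B, J, K, L, 1, 1, D_IN)
    return (x * W).sum(-1) + b.permute(1, 2, 3, 0, 4)


def count_graph_nodes(out):
    """Count the distinct autograd nodes reachable from out.grad_fn."""
    seen, stack = set(), [out.grad_fn]
    while stack:
        fn = stack.pop()
        if fn is None or fn in seen:
            continue
        seen.add(fn)
        stack.extend(next_fn for next_fn, _ in fn.next_functions)
    return len(seen)


def run_and_tally_saved_bytes(fn, x):
    """Run fn(x), summing the size of every tensor autograd saves for backward.

    Views that share storage are counted separately, so this overestimates
    real memory; it is only meant for comparing batch sizes.
    """
    total = 0

    def pack(t):
        nonlocal total
        total += t.numel() * t.element_size()
        return t

    with saved_tensors_hooks(pack, lambda t: t):
        out = fn(x)
    return out, total


for batch in (1, 2):
    x = torch.randn(batch, J, K, L, D_IN, requires_grad=True)
    for fn in (loop_version, vectorized_version):
        out, nbytes = run_and_tally_saved_bytes(fn, x)
        print(batch, fn.__name__, count_graph_nodes(out), "nodes,", nbytes, "bytes saved")
        out.sum().backward()
```

On a real model the same two helpers can be pointed at the actual forward pass; on a GPU, torch.cuda.max_memory_allocated() gives a direct peak-memory number to compare as well.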