I’m trying to train a sequence-to-sequence VAE, and I keep running out of memory at strange times. Is there a way to profile this, or is there an obvious reason it could happen? My sense is that the amount of memory used should scale linearly with the length of the sequence (I have a batch size of one). After I generate my new sequence S’ (of size n x d), I create a pairwise distance matrix of size n x n, and after a few minibatches I run out of memory while doing this. I have 16 GB of VRAM, and what's more, it routinely fails on sequences that are shorter than earlier sequences that ran successfully. Going over my logs, it did not run out of memory on a sequence that produced a 1632x1632 distance matrix, but it did just fail on one that produced a 1504x1504 matrix. What, if anything, could be going on here? I’m happy to provide more details of my implementation if that would help, but it is a straightforward VAE whose encoder and decoder are both based on LSTMs.
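As a rough sanity check on the sizes involved, here's a small sketch that computes the raw storage of a dense tensor, assuming float32 (4 bytes per element). The n x n distance matrix grows quadratically in n, but even at n = 1632 it is only about 10 MiB, so one freshly allocated matrix can't exhaust 16 GB by itself. If the distances are computed by naive broadcasting, however, autograd may keep an n x n x d difference tensor alive for the backward pass; the d = 256 below is purely hypothetical.

```python
def tensor_bytes(shape, bytes_per_element=4):
    """Raw storage of a dense tensor; 4 bytes per element assumes float32."""
    count = 1
    for dim in shape:
        count *= dim
    return count * bytes_per_element


def mib(num_bytes):
    """Convert bytes to mebibytes."""
    return num_bytes / 2**20


# Distance matrices for the two sequence lengths mentioned in the logs.
for n in (1632, 1504):
    print(f"{n}x{n} distance matrix: {mib(tensor_bytes((n, n))):.1f} MiB")

# Hypothetical naive broadcasting, (S[:, None, :] - S[None, :, :]), with d = 256.
diff_bytes = tensor_bytes((1632, 1632, 256))
print(f"1632x1632x256 difference tensor: {diff_bytes / 2**30:.2f} GiB")
```

If memory from earlier minibatches is being retained, the failure point depends on everything that has accumulated so far rather than on the current n, which would be consistent with a shorter sequence failing after a longer one succeeded.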
Is there a best practice for managing GPU memory in this sort of situation? I guess it’s leaking GPU memory somehow, or something else strange is afoot. Is anyone aware of a profiling tool for this sort of thing? Or am I making a mistake elsewhere?
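PyTorch's allocator statistics can serve as a lightweight memory profiler. Below is a minimal sketch, assuming a reasonably recent PyTorch (1.4 or later), where `torch.cuda.memory_allocated`, `torch.cuda.max_memory_allocated`, `torch.cuda.reset_peak_memory_stats`, and `torch.cuda.memory_summary` are available. The helper name `log_gpu_memory` is hypothetical, and it returns zeros on a machine without CUDA so it can be called unconditionally.

```python
import torch


def log_gpu_memory(tag, reset_peak=False):
    """Hypothetical helper: print and return current and peak tensor memory in MiB."""
    if not torch.cuda.is_available():
        stats = {"allocated_mib": 0.0, "peak_mib": 0.0}
    else:
        stats = {
            "allocated_mib": torch.cuda.memory_allocated() / 2**20,
            "peak_mib": torch.cuda.max_memory_allocated() / 2**20,
        }
        if reset_peak:
            # Start a fresh peak measurement for the next iteration.
            torch.cuda.reset_peak_memory_stats()
    print(f"[{tag}] allocated={stats['allocated_mib']:.1f} MiB peak={stats['peak_mib']:.1f} MiB")
    return stats


# Call at the same point in every iteration, e.g. right after optimizer.step().
stats = log_gpu_memory("after step", reset_peak=True)

# Full breakdown of the caching allocator's state (CUDA only).
if torch.cuda.is_available():
    print(torch.cuda.memory_summary(abbreviated=True))
```

Note that `memory_allocated` counts only memory occupied by live tensors. `nvidia-smi` also includes memory the caching allocator has reserved but not handed out (see `torch.cuda.memory_reserved`), so the two numbers won't match.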
Without looking at the code, it’s hard to tell if there’s a silly bug. The most common such error is accumulating the loss across timesteps or iterations as a tensor, like this:
```python
loss = 0
for epoch in range(epochs):
    for input, target in dataset:
        err = criterion(model(input), target)
        # loss += err        #### WRONG, it holds onto the entire graph that `err` was computed from
        loss += err.item()   #### CORRECT, just gets the loss' scalar value (err.data in older PyTorch)
        print(loss)
        ...
```
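To see the difference concretely, here's a small CPU-only sketch. The tiny `nn.Linear` model and random data are stand-ins, not the original poster's VAE. Accumulating the loss tensor yields a result that still has a `grad_fn`, meaning it stays attached to every iteration's graph, whereas `.item()` yields a plain Python float with no graph behind it.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(4, 1)          # stand-in model
criterion = nn.MSELoss()
data = [(torch.randn(1, 4), torch.randn(1, 1)) for _ in range(3)]

leaky_total = 0                  # becomes a tensor after the first addition
safe_total = 0.0                 # stays a Python float
for inp, target in data:
    err = criterion(model(inp), target)
    # Tensor addition: leaky_total's grad_fn chain references each err's graph,
    # so none of those graphs can be freed while leaky_total is alive.
    leaky_total = leaky_total + err
    # .item() copies out the scalar value; no graph is kept.
    safe_total += err.item()

print(type(leaky_total).__name__, leaky_total.grad_fn)
print(type(safe_total).__name__, safe_total)
```

On a GPU the leaky version keeps every iteration's intermediate activations (here, potentially including the n x n distance computation) resident until the accumulator itself is released.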
Thanks for the help. I can’t say for sure what was going on, but it may have been due to how I had my module set up (I also rebuilt from the latest source and updated cuDNN). Basically, I had a VAE with submodules `encoder` and `decoder`, and the VAE module did the reparameterization and took care of the other training boilerplate. After I moved everything into one monolithic module, it no longer seems to be leaking memory.
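For reference, here's a minimal sketch of the layout described above: a `VAE` module that owns `encoder` and `decoder` submodules and does the reparameterization itself. The layer sizes, the choice to feed z to the decoder at every time step, and all class names are assumptions for illustration. Nesting submodules like this is the ordinary `nn.Module` pattern and shouldn't by itself keep autograd graphs alive, so if the leak went away after restructuring, the rebuild or some incidental change made along the way may deserve the credit.

```python
import torch
import torch.nn as nn


class Encoder(nn.Module):
    def __init__(self, d, hidden, latent):
        super().__init__()
        self.lstm = nn.LSTM(d, hidden, batch_first=True)
        self.to_mu = nn.Linear(hidden, latent)
        self.to_logvar = nn.Linear(hidden, latent)

    def forward(self, x):                 # x: (batch, n, d)
        _, (h, _) = self.lstm(x)          # h: (num_layers, batch, hidden)
        h = h[-1]                         # last layer's final hidden state
        return self.to_mu(h), self.to_logvar(h)


class Decoder(nn.Module):
    def __init__(self, latent, hidden, d):
        super().__init__()
        self.lstm = nn.LSTM(latent, hidden, batch_first=True)
        self.out = nn.Linear(hidden, d)

    def forward(self, z, n):              # z: (batch, latent)
        z_seq = z.unsqueeze(1).repeat(1, n, 1)  # feed z at every time step
        y, _ = self.lstm(z_seq)
        return self.out(y)                # (batch, n, d)


class VAE(nn.Module):
    def __init__(self, d, hidden=32, latent=8):
        super().__init__()
        self.encoder = Encoder(d, hidden, latent)
        self.decoder = Decoder(latent, hidden, d)

    def reparameterize(self, mu, logvar):
        # z = mu + sigma * eps, with eps ~ N(0, I)
        std = torch.exp(0.5 * logvar)
        return mu + std * torch.randn_like(std)

    def forward(self, x):
        mu, logvar = self.encoder(x)
        z = self.reparameterize(mu, logvar)
        return self.decoder(z, x.size(1)), mu, logvar


torch.manual_seed(0)
model = VAE(d=5)
x = torch.randn(1, 20, 5)                 # batch size one, n = 20, d = 5
recon, mu, logvar = model(x)
kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
loss = nn.functional.mse_loss(recon, x) + kl
loss.backward()
```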
Or perhaps the memory is still ballooning, just at a slower rate, so that I can now get through an epoch.
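One way to tell a fixed leak from a slower one is to record memory at the same point in every iteration (for instance `torch.cuda.memory_allocated()` right after `optimizer.step()`) and look at the trend. With batch size one and varying n, that post-step reading should stay roughly flat if nothing is retained across iterations, while a leak shows up as a steady upward slope. A minimal pure-Python sketch; the helper name `memory_slope` and the sample readings are hypothetical.

```python
def memory_slope(readings_mib):
    """Least-squares slope, in MiB per iteration, of memory readings
    taken at the same point in each iteration (hypothetical helper)."""
    n = len(readings_mib)
    if n < 2:
        raise ValueError("need at least two readings")
    mean_x = (n - 1) / 2
    mean_y = sum(readings_mib) / n
    num = sum((x - mean_x) * (y - mean_y) for x, y in enumerate(readings_mib))
    den = sum((x - mean_x) ** 2 for x in range(n))
    return num / den


# Hypothetical readings: a flat baseline versus one growing by 10 MiB per iteration.
print(memory_slope([100.0, 100.0, 100.0, 100.0]))
print(memory_slope([100.0, 110.0, 120.0, 130.0]))
```

A slope near zero over many iterations suggests nothing is being retained; a consistently positive slope suggests the leak is still there, only smaller.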