Memory usage when doing Monte Carlo sampling during optimization


I’m working with energy-based models, using contrastive divergence to train my model.

In each iteration, I generate samples from my model with Markov chain Monte Carlo (MCMC):

```python
mcmc_samples = generate_samples(weights, n_samples)
```

Then I compute the energies of these samples and of my training samples. The loss is the difference between the two mean energies.
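
Written out, the loss looks like this. The notation is introduced here for illustration: training samples $x_1,\dots,x_N$, MCMC samples $\tilde{x}_1,\dots,\tilde{x}_M$, and energy $E_\theta$ with weights $\theta$. In contrastive divergence the MCMC samples are treated as fixed inputs when differentiating, so the gradient never has to pass through the sampler:

$$
\mathcal{L}(\theta) = \frac{1}{N}\sum_{i=1}^{N} E_\theta(x_i) - \frac{1}{M}\sum_{j=1}^{M} E_\theta(\tilde{x}_j),
\qquad
\nabla_\theta \mathcal{L} = \frac{1}{N}\sum_{i=1}^{N} \nabla_\theta E_\theta(x_i) - \frac{1}{M}\sum_{j=1}^{M} \nabla_\theta E_\theta(\tilde{x}_j).
$$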

The complete training process looks something like this:

```python
import torch

weights = torch.nn.Parameter(...)
optimizer = torch.optim.SGD([weights], lr=lr)

for i in range(iterations):
    mcmc_samples = generate_samples(weights, n_samples)
    mcmc_energies = get_energies(weights, mcmc_samples)

    train_samples_energies = get_energies(weights, samples)

    loss = - mcmc_energies.mean() + train_samples_energies.mean()

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```

This works well for small models. For large models, I get a CUDA out-of-memory error when calling generate_samples.

I do not get this error when I call generate_samples with the same arguments outside the training process, where weights is a plain tensor rather than a parameter.

I understand that this is because, with parameters, autograd has to keep the intermediate data needed later for the gradient calculation, which takes more memory. But in this case, the gradient I need is independent of the generate_samples step.
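
A small sketch of why the memory grows. Here toy_sampler is a hypothetical stand-in for generate_samples, not a real MCMC sampler: it only makes every step depend on the weights. Any tensor computed from a Parameter gets a grad_fn, so autograd records each sampling step and keeps alive whatever those steps saved for the backward pass. Under torch.no_grad() nothing is recorded:

```python
import torch

# Hypothetical stand-in for generate_samples (not a real MCMC sampler):
# every step uses the weights, so every step is recorded by autograd.
def toy_sampler(weights, n_samples, n_steps=5):
    x = torch.zeros(n_samples, weights.shape[0])
    for _ in range(n_steps):
        x = x + 0.1 * torch.randn_like(x) + 0.01 * weights
    return x

weights = torch.nn.Parameter(torch.randn(3))

# Called with a Parameter: the output is attached to a graph of all steps.
samples = toy_sampler(weights, n_samples=4)
print(samples.requires_grad, samples.grad_fn is not None)  # True True

# Same call inside no_grad: no graph, so nothing is kept for backward.
with torch.no_grad():
    samples_ng = toy_sampler(weights, n_samples=4)
print(samples_ng.requires_grad, samples_ng.grad_fn is None)  # False True
```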

Is there a way to tell PyTorch to ignore the generate_samples line during the gradient calculation?
Or is there a different approach to using PyTorch with Monte Carlo sampling?

For now, I have worked around this by computing the gradients myself and implementing the optimization step by hand, but I would rather use PyTorch's automatic differentiation and built-in optimizers.


Are you detaching your samples? Otherwise, Autograd will backpropagate through your sampling process as well, which you don’t want at all (and which will lead to an ever-increasing memory footprint).
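
A minimal runnable sketch of this suggestion, under stated assumptions: the energy (a unit Gaussian centred at the weights) and the random-walk Metropolis-Hastings sampler are hypothetical stand-ins for get_energies and generate_samples. Because this sampler never needs gradients, the whole call can be wrapped in torch.no_grad(), which for this purpose has the same effect as detaching its output. Only the two energy evaluations enter the graph:

```python
import torch

torch.manual_seed(0)

def get_energies(weights, x):
    # Hypothetical toy energy: E_w(x) = 0.5 * ||x - w||^2 (unit Gaussian at w).
    return 0.5 * ((x - weights) ** 2).sum(dim=1)

def generate_samples(weights, n_samples, n_steps=100, step_size=1.0):
    # Hypothetical random-walk Metropolis-Hastings sampler; it only evaluates
    # the energy and never needs its gradient.
    x = torch.randn(n_samples, weights.shape[0])
    e = get_energies(weights, x)
    for _ in range(n_steps):
        proposal = x + step_size * torch.randn_like(x)
        e_prop = get_energies(weights, proposal)
        # Accept with probability min(1, exp(E(x) - E(proposal))).
        accept = torch.rand(n_samples) < torch.exp(e - e_prop)
        x = torch.where(accept.unsqueeze(1), proposal, x)
        e = torch.where(accept, e_prop, e)
    return x

true_mean = torch.tensor([1.0, -1.0])
samples = true_mean + torch.randn(1000, 2)  # toy training data

weights = torch.nn.Parameter(torch.zeros(2))
optimizer = torch.optim.SGD([weights], lr=0.1)

for _ in range(150):
    # No graph is recorded for the sampling steps.
    with torch.no_grad():
        mcmc_samples = generate_samples(weights, n_samples=256)

    mcmc_energies = get_energies(weights, mcmc_samples)
    train_samples_energies = get_energies(weights, samples)
    loss = -mcmc_energies.mean() + train_samples_energies.mean()

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

print(weights.detach())  # should end up close to true_mean for this toy model
```

For this toy energy the gradient of the loss is the mean of the MCMC samples minus the mean of the training data, so the weights drift toward the data mean.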

Actually, no. Does that mean that inside my sampler I have to detach after every operation involving my model parameters?


Thanks @AlphaBetaGamma96! Inside generate_samples, I detached every tensor resulting from a computation with my model parameters, and that solved the issue.
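
A possible alternative to detaching each intermediate, sketched under assumptions: detach the parameters once on entry to the sampler, so nothing computed inside it links back to weights. Langevin dynamics is swapped in here only as an example of a gradient-based sampler (the thread does not say which MCMC method is used). Such a sampler needs the gradient of the energy with respect to x and so cannot run under torch.no_grad(). The chain state is still detached after each step so earlier steps can be freed. generate_samples_langevin, the toy energy, and the step size are hypothetical:

```python
import torch

torch.manual_seed(0)

def get_energies(weights, x):
    # Hypothetical toy energy: E_w(x) = 0.5 * ||x - w||^2.
    return 0.5 * ((x - weights) ** 2).sum(dim=1)

def generate_samples_langevin(weights, n_samples, n_steps=100, step_size=0.1):
    # Detach once: the sampler uses the current values of the weights,
    # but no operation in here is traced back to the Parameter.
    w = weights.detach()
    x = torch.randn(n_samples, w.shape[0])
    for _ in range(n_steps):
        x.requires_grad_(True)
        e = get_energies(w, x).sum()
        # Gradient with respect to the samples only, not the weights.
        (grad_x,) = torch.autograd.grad(e, x)
        # Unadjusted Langevin step; detach so the graph does not span all steps.
        x = (x - 0.5 * step_size * grad_x
             + step_size ** 0.5 * torch.randn_like(x)).detach()
    return x

weights = torch.nn.Parameter(torch.tensor([1.0, -1.0]))
mcmc_samples = generate_samples_langevin(weights, n_samples=512)
print(mcmc_samples.requires_grad)  # False
print(mcmc_samples.mean(dim=0))    # close to the weights for this toy energy
```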