Only calling .backward() once, but I'm still getting an error telling me to set "retain_graph=True"

Hi, I am attempting to train a Siamese neural network that defines a particular embedding function f(x), while simultaneously optimizing a clustering model (a Gaussian mixture model) on that embedding space.

I want the NN weights to be updated with respect to both a loss function that measures the quality of the embedding and the Gaussian mixture model's variational loss (ELBO).

However, when I call loss.backward(), I get the following error:

RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.

Here is my code. I've omitted the GMM code and a few other functions, as they are about 1000 lines and don't seem particularly helpful to include.

class Siamese(nn.Module):
    def __init__(self, input_size):
        super(Siamese, self).__init__()

        self.hidden = nn.Linear(input_size, 2)
        self.out = nn.Linear(2, 2)

    def forward_one(self, x):
        x = self.hidden(x)
        x = self.out(x)
        return x

    def forward(self, x1, x2):
        out1 = self.forward_one(x1)
        out2 = self.forward_one(x2)
        dis = torch.norm(out1 - out2, dim=1)
        return dis

data, labels = load_dataset()
net = Siamese(data.shape[1])
net_optim = torch.optim.Adam(net.parameters(), lr=0.05, weight_decay=1)

# initialize weights, means, and covariances for the Gaussian clusters
concentrations, means, covariances, precisions = initialization(net.forward_one(data)) 

for i in range(1000):
    net_optim.zero_grad()
    pairs, pair_labels = pairGenerator(data, labels) # samples some pairs of datapoints
    outputs = net(pairs[:, 0, :], pairs[:, 1, :]) # computes pairwise distances

    embedding = net.forward_one(data) # embeds all data in the NN space

    log_prob, log_likelihoods = expectation_step(embedding, means, precisions, concentrations)
    concentrations, means, covariances, precisions = maximization_step(embedding, log_likelihoods)

    loss = FullLoss(outputs, pair_labels, log_likelihoods, log_prob, precisions, concentrations)

    loss.backward()
    net_optim.step()

FullLoss is a loss computed from the pairwise distances in the NN embedding space, plus the Gaussian mixture model loss (based on the likelihoods under the current weights, means, and covariances).

Can anyone tell me why my .backward() call requires retain_graph=True? This is impractical for my use case: by around the 400th iteration, each iteration takes 30 minutes.

Thanks!

Based on your description, it seems that each iteration gets slower if you set retain_graph=True.
This points to the computation graph being stored across iterations, which would also explain the initial error message. Do you additionally see a growth in memory usage?

Could you check whether you store any output of the model or the loss directly, e.g. in a Python list?

Thanks @ptrblck,

This occurs even if I don’t store the loss. Yes, I see a huge growth in memory usage; I am sure that it is storing the entire computation graph (across iterations). Is there a way to prevent this?

PyTorch doesn't store the graph automatically across iterations, so you should check whether you have a code snippet such as losses.append(loss), which stores a tensor that wasn't detached from the computation graph.
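
For example (a minimal sketch with a toy model, not code from your project): appending the loss tensor itself keeps its whole graph alive, while appending a detached value does not.

import torch
import torch.nn as nn

model = nn.Linear(2, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=0.05)
x, y = torch.randn(100, 2), torch.randn(100, 1)

losses = []
for _ in range(10):
    optimizer.zero_grad()
    loss = ((model(x) - y) ** 2).mean()
    loss.backward()
    optimizer.step()
    # losses.append(loss)       # would keep the whole graph of this iteration alive
    losses.append(loss.item())  # stores only the Python float; the graph can be freed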

I am definitely not explicitly appending any loss to an array. My optimization code verbatim is:

N, D = data.shape

# Hyperparameters
T = 4
n_iter = 500
batch_size = 10

mean_prec = 1.
dof = D
gamma = 1.5
alpha = 1.5

net = Siamese(D)
net_optim = torch.optim.Adam(net.parameters(), lr=0.05, weight_decay=1)

concentrations, means, covariances, precisions = initialization(net.forward_one(data)) 

for i in tqdm(range(n_iter)):

    net_optim.zero_grad()
    pairs, pair_labels = nnPairIdxsGenerator(data, labels, batch_size, 5)

    embedding = net.forward_one(data)

    outputs = net(embedding[pairs[:, 0]], embedding[pairs[:, 1]])

    mean_lwp, log_likelihoods = compute_expectations(
        embedding, means, precisions, concentrations)
    concentrations, means, covariances, precisions = max_probabilities(
        embedding, log_likelihoods)

    dpmm_loss = elbo(log_likelihoods, precisions, concentrations)
    net_loss = SiameseLoss(outputs, pair_labels)
    net_loss.backward(retain_graph=True)
    dpmm_loss.backward(retain_graph=True)
    net_optim.step()

If it would be useful to see the contents of the methods I am calling in the above, please let me know.

Thanks again!

Yes, a reproducible code snippet would be great so that we can have a look at what's going on.

Thanks @ptrblck. I omitted this initially as it is a lot of code, sorry about that. Really appreciate your time.

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import *
import pandas as pd
import numpy as np
from sklearn.neighbors import NearestNeighbors
import warnings
warnings.filterwarnings("ignore")

M = 50
data = torch.cat((MultivariateNormal(-7 * torch.ones(2), torch.eye(2)).sample([M]),
                  MultivariateNormal(torch.tensor([-3., 0]), torch.eye(2)).sample([M]),
                  MultivariateNormal(torch.tensor([0, 3.]), torch.eye(2)).sample([M]),
                  MultivariateNormal(torch.tensor([-8, 10.]), torch.eye(2)).sample([M]),
                  MultivariateNormal(torch.tensor([10., -2]), torch.eye(2)).sample([M]),
                  MultivariateNormal(7 * torch.ones(2), torch.eye(2)).sample([M])))

N, D = data.shape
labels = torch.tensor([1 if i < 3*M else 0 for i in range(N)])

class Siamese(nn.Module):
    def __init__(self, input_size):
        super(Siamese, self).__init__()

        self.hidden = nn.Linear(input_size, 2) 
        self.out = nn.Linear(2, 2)

    def forward_one(self, x):
        x = self.hidden(x)
        x = self.out(x)
        return x

    def forward(self, x1, x2):
        out1 = self.forward_one(x1)
        out2 = self.forward_one(x2)
        dis = torch.norm(out1 - out2, dim=1)
        return dis

def SiameseLoss(outputs, labels):
    '''
    outputs: batch size
    labels: batch size

    Returns:
      loss
    '''
    samepairslen = outputs[labels == 1].shape[0]
    diffpairslen = outputs[labels == 0].shape[0]
    if diffpairslen == 0:
        return (samepairslen ** -1) * torch.sum(outputs[labels == 1])
    elif samepairslen == 0:
        return - (diffpairslen ** -1) * torch.sum(outputs[labels == 0])
    else:
        return (samepairslen ** -1) * torch.sum(outputs[labels == 1]) - (diffpairslen ** -1) * torch.sum(outputs[labels == 0])

def nnPairIdxsGenerator(data, labels, batch_size, k):
    '''
      Returns:
        pairs: batch_size * k x 2 
        pairs of INDICES of datapoints 
    '''
    N, D = data.shape
    rindxs = torch.randint(low=0, high=N, size=(batch_size,))

    _, nbrs = NearestNeighbors(n_neighbors=k+1, algorithm='ball_tree').fit(data.detach()).kneighbors(data[rindxs])

    pairs = np.zeros((batch_size * k, 2), dtype=int)
    pair_labels = torch.zeros(batch_size * k)

    for idx in range(0, batch_size):
        for jdx in range(k):
            pairs[idx*k + jdx, 0] = nbrs[idx, 0]
            pairs[idx*k + jdx, 1] = nbrs[idx, jdx+1]
            pair_labels[idx*k + jdx] = (labels[nbrs[idx, 0]] == labels[nbrs[idx, jdx+1]]).type(torch.int)

    return pairs, pair_labels

from sklearn.cluster import KMeans

def betaln(a, b):
    return torch.lgamma(a) + torch.lgamma(b) - torch.lgamma(a+b)

def estimate_parameters(data, assignments):
    assignment_counts = torch.sum(assignments, dim=0) + 1e-15
    means = torch.mm(assignments.T, data) / assignment_counts[:, None]
    sq_means = 2 * means * \
        torch.mm(assignments.T, data) / assignment_counts[:, None]

    avg_cov = torch.mm(assignments.T, data * data) / \
        assignment_counts[:, None]
    covariances = avg_cov - sq_means + means**2 + 1e-7
    return assignment_counts, means, covariances


def estimate_weights(assignment_counts):
    # assignment counts shape = (T,)

    stack_cumsum = torch.cat((torch.flip(torch.cumsum(torch.flip(assignment_counts, dims=(0,)), dim=0), dims=(0,))[1:], torch.zeros(1))) # ???
    weight_concentration = (1. + assignment_counts, (alpha +
                                                     stack_cumsum))
    return weight_concentration


def estimate_means(data, assignment_counts, means):
    new_mean_prec = 1. + assignment_counts

    means_unnorm = (1. * torch.mean(data, dim=0) +
                    assignment_counts[:, None] * means)
    new_means = (means_unnorm / new_mean_prec[:, None])
    return new_means, new_mean_prec


def estimate_precisions(data, assignment_counts, means, covariances):
    new_dof = D + assignment_counts
    diff = means - torch.mean(data, dim=0)
    new_covariances = (gamma *
                       torch.var(data, dim=0) +
                       assignment_counts[:, None] *
                       (covariances +
                        (1. /
                         mean_prec)[:, None] *
                           torch.mul(diff, diff)))
    new_covariances /= new_dof[:, None]

    new_precisions = 1. / torch.sqrt(new_covariances)
    return new_covariances, new_precisions, new_dof

def initialization(data):
    global mean_prec
    global dof

    label_matrix = torch.zeros(N, T)
    assignments = KMeans(n_clusters=T, n_init=1).fit(data.detach().numpy()).labels_
    label_matrix[torch.arange(N), assignments] = 1

    counts, means, covariances = estimate_parameters(data, label_matrix)
    weight_conc = estimate_weights(counts)
    new_means, mean_prec = estimate_means(data, counts, means)
    new_covariances, new_precisions, dof = estimate_precisions(
        data, counts, means, covariances)

    return weight_conc, new_means, new_covariances, new_precisions


def compute_expectations(data, means, precisions_pd, concentrations):
    def log_determinant(x): return torch.sum(torch.log(x), dim=1)
    # Estimate log Gaussian probability (requires means, Cholesky precision
    # matrix, data, dof)
    precisions = precisions_pd ** 2
    sum_means_sq = torch.sum((means ** 2 * precisions), dim=1)
    log_prob = (
        sum_means_sq -
        2. *
        torch.mm(
            data,
            (means *
             precisions).T) +
        torch.mm(
            data ** 2,
            precisions.T))
    log_prob_g = (-.5 * (D * np.log(2 * np.pi) + log_prob) +
                  log_determinant(precisions_pd)) - (.5 * D * np.log(dof))

    # Estimate total log probability (requires data, dof)
    new_dofD = (dof - torch.arange(0, D)[:, None])
    log_prob_l = D * np.log(2.) + torch.sum(torch.digamma(.5 *
                                                 new_dofD), dim=0)
    log_prob_total = log_prob_g + 0.5 * (log_prob_l - D / mean_prec)

    # Estimate log weights
    digamma_sum = torch.digamma(concentrations[0] + concentrations[1])
    log_weights = torch.digamma(concentrations[0]) - digamma_sum + torch.cat(
        (torch.zeros(1), torch.cumsum(torch.digamma(concentrations[1]) - digamma_sum, dim=0)[:-1])) # ??? stack

    # Estimate log likelihoods
    log_weighted_prob = log_weights + log_prob_total
    softmax_log_weighted_prob = torch.logsumexp(log_weighted_prob, dim=1)
    # with np.errstate(under='ignore'):
    log_likelihoods = log_weighted_prob - \
        softmax_log_weighted_prob[:, None]
    mean_log_weighted_prob = torch.mean(softmax_log_weighted_prob)

    return mean_log_weighted_prob, log_likelihoods  # shapes are also NxD


def max_probabilities(data, log_likelihoods):
    counts, means, covariances = estimate_parameters(
        data, torch.exp(log_likelihoods))
    weight_conc = estimate_weights(counts)
    new_means, mean_prec = estimate_means(data, counts, means)
    new_covariances, new_precisions, dof = estimate_precisions(
        data, counts, means, covariances)
    return weight_conc, new_means, new_covariances, new_precisions


def elbo(log_likelihoods, precisions, weight_concentrations):
    log_determinant = (
        torch.sum(
            torch.log(precisions),
            dim=1) -
        0.5 *
        D *
        torch.log(dof))

    log_norm = -(dof * log_determinant +
                 dof * D * .5 * np.log(2.) +
                 torch.sum(torch.tensor(torch.lgamma(.5 * (dof - torch.arange(D)[:, None]))), dim=0))

    log_norm = torch.sum(log_norm)
    log_norm_weight = - \
        torch.sum(betaln(weight_concentrations[0], weight_concentrations[1]))
    return (-torch.sum(torch.exp(log_likelihoods) * log_likelihoods) -
            log_norm - log_norm_weight - 0.5 * D * torch.sum(torch.log(mean_prec)))

N, D = data.shape

# Hyperparameters
T = 4
n_iter = 500
batch_size = 10

mean_prec = 1.
dof = D
gamma = 1.5
alpha = 1.5

net = Siamese(D)
net_optim = torch.optim.Adam(net.parameters(), lr=0.05, weight_decay=1)

concentrations, means, covariances, precisions = initialization(net.forward_one(data)) 

for i in range(n_iter):
    net_optim.zero_grad()
    # pairs, pair_labels = pairGenerator(data, labels, batch_size)
    pairs, pair_labels = nnPairIdxsGenerator(data, labels, batch_size, 5)

    embedding = net.forward_one(data)

    outputs = net(embedding[pairs[:, 0]], embedding[pairs[:, 1]])

    mean_lwp, log_likelihoods = compute_expectations(
        embedding, means, precisions, concentrations)
    concentrations, means, covariances, precisions = max_probabilities(
        embedding, log_likelihoods)

    dpmm_loss = elbo(log_likelihoods, precisions, concentrations)
    net_loss = SiameseLoss(outputs, pair_labels)
    net_loss.backward(retain_graph=True)
    dpmm_loss.backward(retain_graph=True)

    net_optim.step()

    
with torch.no_grad():
    _, log_likelihoods = compute_expectations(net.forward_one(data), means, precisions, concentrations)
    assignments = log_likelihoods.argmax(axis=1)

This does not run: data is not defined in the main scope, so N, D = data.shape fails.

Sorry @albanD, I missed copying one of my notebook cells. Edited the above.

Also, the above code runs fine (for me, at least) as is; it's when the retain_graph=True flags are removed that it breaks.

In this code sample, you have two calls to backward, so it is expected that you need retain_graph=True for the first one.
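
If you only need one parameter update from both losses, a sketch that avoids retain_graph for this particular pair of calls is to sum them and call backward once inside your loop (this is not a fix for the cross-iteration growth described below):

total_loss = net_loss + dpmm_loss
total_loss.backward()
net_optim.step()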

Your problem is that dpmm_loss will still fail (or have an increasingly larger computational graph).
This is because initialization() is a differentiable function: its outputs do require gradients. And these outputs are inputs to every iteration of your loop, so every iteration will backprop through this part of the graph (the one that corresponds to the initialization function).
If you don't want that, you can detach() the results of the initialization.

Thanks! Really appreciate you taking the time to figure that out.

How should I detach() the results of the initialization? I have done:

concentrations, means, covariances, precisions = initialization(net.forward_one(data))
concentrations = (concentrations[0].detach(),  concentrations[1].detach())
means = means.detach()
covariances = covariances.detach()
precisions = precisions.detach() 

But I’m still experiencing the same slowing of each iteration as I proceed through the loop. The same thing happens if I wrap the initialization with a torch.no_grad().

The problem is that in your loop, precisions (and the other GMM parameters) are both used and written to.
This means that the next iteration depends on the operations from the previous iteration (and thus all the ones before it).
If you only want backprop to compute gradients for the current iteration, you want to do something similar to what we do in RNNs: at the beginning of each iteration, .detach() everything that happens to carry gradients from previous iterations but shouldn't. In your case, that's the 4 lines you have above, at the beginning of each iteration.
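
For example, something along these lines (a sketch using the variables from the loop you posted):

for i in range(n_iter):
    # cut the graph coming from previous iterations before reusing the GMM parameters
    concentrations = (concentrations[0].detach(), concentrations[1].detach())
    means = means.detach()
    covariances = covariances.detach()
    precisions = precisions.detach()

    net_optim.zero_grad()
    ...  # rest of the iteration as before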


Thanks, that worked perfectly.
