Only Perform Backwards Pass wrt Single Entry in Batch?

I am using a variational model whose output has shape (B x S x D), where B is the batch size, S is the number of samples I draw from my encoded distribution, and D is the feature size. After each forward pass, I only want to backpropagate my loss through a single sample per batch entry (i.e., one index along dimension S). To do this, I index along S and call backward on only the selected entries. I believe autograd is still computing the gradient for all other (S - 1) samples, because I run out of memory at the backward step. Reducing S fixes the error, but that significantly harms my model's performance. How can I ensure that only a single S entry is used in the backward pass? (This problem is essentially the same as backpropagating through a single batch entry, hence the title.)

Sample code:

import torch
import torch.nn.functional as F

# x: (B, D); the VAE returns x_hat: (B, S, D) and a per-sample kl_loss: (B, S)
x_hat, kl_loss = vae(x)
recon_loss = F.mse_loss(x_hat, x.unsqueeze(1).expand_as(x_hat), reduction='none').mean(dim=-1)  # (B, S)
comb_loss = (recon_loss + kl_weight * kl_loss)
max_elbo = torch.argmin(comb_loss, dim=1).detach()  # lowest loss = highest ELBO per batch entry
tot_loss = comb_loss[torch.arange(len(x)), max_elbo].mean()

# Free memory associated with all other forward passes
recon_loss = recon_loss[torch.arange(len(x)), max_elbo].mean().item()
kl_loss = kl_loss[torch.arange(len(x)), max_elbo].mean().item()
comb_loss, max_elbo = 0., 0.

# Take gradient step
opt.zero_grad()
tot_loss.backward()
opt.step()
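Building the (B x S x D) reconstruction target with `repeat` copies x once per sample, whereas `expand` returns a broadcast view over the same storage. A minimal sketch of the difference, with made-up sizes and a random tensor standing in for the decoder output. Note this only trims the target tensor; it does not change which activations autograd saves for the decoder.

```python
import torch
import torch.nn.functional as F

B, S, D = 4, 8, 3
x = torch.randn(B, D)          # input batch
x_hat = torch.randn(B, S, D)   # stand-in for the decoder output (B x S x D)

# expand() returns a view: no new memory, stride 0 along the sample dimension
target_view = x.unsqueeze(1).expand(B, S, D)
# repeat() materializes S full copies of x
target_copy = x.unsqueeze(1).repeat(1, S, 1)

print(target_view.stride())  # (3, 0, 1)
print(torch.equal(target_view, target_copy))  # True

# Per-sample reconstruction loss, shape (B, S)
recon_loss = F.mse_loss(x_hat, target_view, reduction='none').mean(dim=-1)
print(recon_loss.shape)  # torch.Size([4, 8])
```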

As I understand it, you want to backpropagate only through the sample that maximizes the evidence lower bound (ELBO), i.e. the one with the lowest loss, for each batch entry.

In your case, the meaningful gradients come only from the selected indices of maximum ELBO (max_elbo). However, indexing does not prune the graph: its backward scatters the incoming gradient into a zero tensor of the full (B x S) shape, so a gradient of 0 is passed back to every other sample, and the computation graph from the forward pass still holds the saved buffers (activations) for all of those samples. Thus the gradient computation, and the memory it needs, still covers all the samples, not only the selected ones.
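A small sketch of the zero-filled gradient, using a leaf tensor as a stand-in for the per-sample decoder output (sizes are arbitrary): after indexing, backward() still produces a gradient with the full (B, S, D) shape, and only one sample per batch entry gets non-zero values.

```python
import torch

B, S, D = 2, 5, 3
# Leaf tensor standing in for the per-sample decoder output (B x S x D)
out = torch.randn(B, S, D, requires_grad=True)

per_sample_loss = out.pow(2).mean(dim=-1)   # (B, S)
best = per_sample_loss.argmin(dim=1)        # one sample index per batch entry
per_sample_loss[torch.arange(B), best].mean().backward()

# The gradient has the full (B, S, D) shape ...
print(out.grad.shape)  # torch.Size([2, 5, 3])
# ... but only one sample per batch entry receives non-zero values
nonzero_samples = (out.grad.abs().sum(dim=-1) != 0).sum(dim=1)
print(nonzero_samples)  # tensor([1, 1])
```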

One workaround I can think of is a two-forward-pass, one-backward-pass scheme on the decoder side: in the first forward pass (under torch.no_grad()), find the indices of maximum ELBO (max_elbo); then, in the second forward pass, decode only those samples and do the actual gradient calculation.

# assume vae has an encoder, a sampler, and a decoder
# first forward pass to figure out the max ELBO samples

mean, var = vae.encoder(x)
samples = vae.sampler(mean, var) # B, S, D

with torch.no_grad():
    x_hat = vae.decoder(samples)
    recon_loss = F.mse_loss(x_hat, x.unsqueeze(1).expand_as(x_hat), reduction='none').mean(dim=-1)
    comb_loss = (recon_loss + kl_weight * kl_loss)  # kl_loss: per-sample KL term, shape (B, S)
    max_elbo = torch.argmin(comb_loss, dim=1)

# select max ELBO samples
samples = samples[torch.arange(len(x)), max_elbo]

# actual forward pass thru decoder
x_hat = vae.decoder(samples)
...

# loss calculated and backward
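For reference, a self-contained sketch of the two-pass scheme end to end. Everything model-specific here is an assumption for illustration: `TinyVAE`, its layer sizes, the `log_var` parameterization (the snippet uses `var`), and the single-sample Monte Carlo KL term (log q(z|x) - log p(z)) chosen so that kl_loss is per-sample with shape (B, S), as in the question.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal


class TinyVAE(nn.Module):
    """Hypothetical toy VAE exposing an encoder / sampler / decoder split."""

    def __init__(self, in_dim=20, z_dim=4, n_samples=16):
        super().__init__()
        self.n_samples = n_samples
        self.enc = nn.Linear(in_dim, 2 * z_dim)
        self.dec = nn.Sequential(nn.Linear(z_dim, 64), nn.ReLU(), nn.Linear(64, in_dim))

    def encoder(self, x):
        mean, log_var = self.enc(x).chunk(2, dim=-1)  # each (B, z_dim)
        return mean, log_var

    def sampler(self, mean, log_var):
        # Reparameterized samples, shape (B, S, z_dim); gradients flow to mean/log_var
        std = (0.5 * log_var).exp()
        eps = torch.randn(mean.shape[0], self.n_samples, mean.shape[1])
        return mean.unsqueeze(1) + std.unsqueeze(1) * eps

    def decoder(self, z):
        return self.dec(z)  # works on (B, S, z_dim) or (B, z_dim)


def train_step(vae, opt, x, kl_weight=1.0):
    B = x.shape[0]
    mean, log_var = vae.encoder(x)
    samples = vae.sampler(mean, log_var)  # (B, S, z_dim), tracked by autograd

    # Monte Carlo KL estimate per sample: log q(z|x) - log p(z), shape (B, S)
    q = Normal(mean.unsqueeze(1), (0.5 * log_var).exp().unsqueeze(1))
    kl_loss = (q.log_prob(samples) - Normal(0.0, 1.0).log_prob(samples)).sum(dim=-1)

    # Pass 1: score every sample without saving decoder activations for backward
    with torch.no_grad():
        x_hat_all = vae.decoder(samples)  # (B, S, D)
        recon_all = F.mse_loss(x_hat_all, x.unsqueeze(1).expand_as(x_hat_all),
                               reduction='none').mean(dim=-1)  # (B, S)
        best = torch.argmin(recon_all + kl_weight * kl_loss, dim=1)  # (B,)

    # Pass 2: decode only the selected sample of each batch entry, with autograd on
    idx = torch.arange(B)
    x_hat = vae.decoder(samples[idx, best])  # (B, D)
    recon = F.mse_loss(x_hat, x, reduction='none').mean(dim=-1)  # (B,)
    loss = (recon + kl_weight * kl_loss[idx, best]).mean()

    opt.zero_grad()
    loss.backward()
    opt.step()
    return loss.item(), best


torch.manual_seed(0)
vae = TinyVAE()
opt = torch.optim.SGD(vae.parameters(), lr=1e-2)
x = torch.randn(8, 20)
loss, best = train_step(vae, opt, x)
print(loss, best)
```

Pass 1 still runs the decoder on all S samples, but under no_grad its intermediate activations are freed as the pass proceeds instead of being kept for backward. One caveat: the scheme relies on the decoder being deterministic. With dropout or other stochastic layers, pass 2 may not reproduce the outputs pass 1 used to pick `best`.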

Hi Arul, thank you for the recommendation. It appears to be working.

Here is a simplified testbed I used to validate this. Original method:

import time
import torch
import torch.nn as nn
import torch.nn.functional as F

for num_samples in [5, 10, 100, 1000, 10000, 50000]:
    x = nn.Parameter(torch.randn(20))
    y = torch.randn((128, 20))
    opt = torch.optim.SGD([x], lr=1e-2)

    samp = torch.randn((num_samples, 128, 1))
    y_hat = samp * x

    opt.zero_grad()
    loss = F.mse_loss(y_hat, y.repeat(num_samples, 1, 1), reduction='none').mean(dim=-1)
    min_loss = loss.argmin(dim=0).detach()
    opt_loss = loss[min_loss, torch.arange(len(y))]
    pretime = time.time()
    opt_loss.mean().backward()
    print(f"{num_samples} samples, {time.time() - pretime} secs")

Output:

5 samples, 0.00019860267639160156 secs
10 samples, 0.0001697540283203125 secs
100 samples, 0.0006811618804931641 secs
1000 samples, 0.012559652328491211 secs
10000 samples, 0.16070032119750977 secs
50000 samples, 0.5388662815093994 secs

Your proposed approach:

for num_samples in [5, 10, 100, 1000, 10000, 50000]:
    x = nn.Parameter(torch.randn(20))
    y = torch.randn((128, 20))
    opt = torch.optim.SGD([x], lr=1e-2)

    samp = torch.randn((num_samples, 128, 1))

    with torch.no_grad():
        y_hat = samp * x
        loss = F.mse_loss(y_hat, y.repeat(num_samples, 1, 1), reduction='none').mean(dim=-1)
        min_loss = loss.argmin(dim=0).detach()

    samp = samp[min_loss, torch.arange(len(y))]
    y_hat = samp * x

    opt.zero_grad()
    loss = F.mse_loss(y_hat, y, reduction='none').mean(dim=-1)
    pretime = time.time()
    loss.mean().backward()
    print(f"{num_samples} samples, {time.time() - pretime} secs")

Output:

5 samples, 0.0002472400665283203 secs
10 samples, 0.00019288063049316406 secs
100 samples, 0.00018095970153808594 secs
1000 samples, 0.0001964569091796875 secs
10000 samples, 0.00018358230590820312 secs
50000 samples, 0.0002276897430419922 secs
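Timing aside, it may be worth confirming that both versions optimize the same objective. A sketch reusing the testbed's shapes; the fixed seed, the single `num_samples` value, and using `expand` instead of `repeat` for the target are my choices. It compares the selected indices and the resulting gradients for the same starting parameter.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
num_samples = 1000
x_init = torch.randn(20)
y = torch.randn((128, 20))
samp = torch.randn((num_samples, 128, 1))

# Original method: graph over all samples, then index the minimum-loss ones
x1 = nn.Parameter(x_init.clone())
loss = F.mse_loss(samp * x1, y.expand(num_samples, -1, -1), reduction='none').mean(dim=-1)
min_loss = loss.argmin(dim=0)
loss[min_loss, torch.arange(len(y))].mean().backward()

# Two-pass method: select under no_grad, then recompute only the selected samples
x2 = nn.Parameter(x_init.clone())
with torch.no_grad():
    loss_ng = F.mse_loss(samp * x2, y.expand(num_samples, -1, -1), reduction='none').mean(dim=-1)
    min_loss_ng = loss_ng.argmin(dim=0)
sel = samp[min_loss_ng, torch.arange(len(y))]
F.mse_loss(sel * x2, y, reduction='none').mean(dim=-1).mean().backward()

print(torch.equal(min_loss, min_loss_ng))
print(torch.allclose(x1.grad, x2.grad, atol=1e-6))
```

The gradients can differ in the last few bits because the two backward passes sum the same terms in a different order, hence `allclose` rather than an exact comparison.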

I am glad it works for your use case and thanks for providing the testbed code here!