I am using a variational model whose output has shape (B x S x D), where B is the batch size, S is the number of samples I draw from my encoded distribution, and D is the feature size. After each forward pass, I only want to backpropagate the loss with respect to a single sample (i.e., a single entry along dimension S). To do this, I index along S and call backward on only that entry.

I believe autograd is still computing the gradient for all other (S - 1) entries, because I run out of memory at the backward step. I can avoid the error by reducing S, but that significantly hurts my model's performance. How can I ensure that only a single entry along S is used for the backward pass? (Note that this problem is essentially the same as computing the backward pass for a single batch entry, hence the title.)
Sample code:
x_hat, kl_loss = vae(x)  # x_hat: (B, S, D), kl_loss: (B, S)
recon_loss = F.mse_loss(x_hat, x.unsqueeze(1).expand_as(x_hat), reduction='none').mean(dim=-1)  # (B, S)
comb_loss = (recon_loss + kl_weight * kl_loss)
max_elbo = torch.argmin(comb_loss, dim=1).detach()  # best sample per batch element, shape (B,)
tot_loss = comb_loss[torch.arange(len(x)), max_elbo].mean()
# Free memory associated with all other forward passes
recon_loss = recon_loss[torch.arange(len(x)), max_elbo].mean().item()
kl_loss = kl_loss[torch.arange(len(x)), max_elbo].mean().item()
comb_loss, max_elbo = 0., 0.
# Take gradient step
opt.zero_grad()
tot_loss.backward()
opt.step()
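For reference, here is a minimal, self-contained sketch of the selection step described above, with an illustrative linear decoder and made-up shapes standing in for the actual VAE. It shows that the gradient reaching the non-selected samples is exactly zero after `backward()`; the memory pressure comes from the forward activations stored for all S samples, not from nonzero gradients:

```python
import torch

# Illustrative stand-ins for the VAE: a linear decoder and small shapes.
torch.manual_seed(0)
B, S, D = 4, 8, 16

x = torch.randn(B, D)                          # input batch
z = torch.randn(B, S, D, requires_grad=True)   # S samples per batch element
decoder = torch.nn.Linear(D, D)

x_hat = decoder(z)                                         # (B, S, D)
recon_loss = ((x_hat - x.unsqueeze(1)) ** 2).mean(dim=-1)  # (B, S)

best = recon_loss.argmin(dim=1).detach()                   # (B,)
tot_loss = recon_loss[torch.arange(B), best].mean()
tot_loss.backward()

# The gradient at every non-selected sample is exactly zero, but the
# activations for all S samples were already saved during the forward
# pass, which is what drives memory usage at backward time.
mask = torch.zeros(B, S, dtype=torch.bool)
mask[torch.arange(B), best] = True
print(z.grad.abs().sum(dim=-1)[~mask].max())   # tensor(0.)
```

Because the indexing happens after the forward pass, reducing the backward cost this way does not reduce the forward-activation memory; that is the distinction at the heart of the question.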