Hi, I need to re-run a variational inference deep learning model (using the reparameterization trick) many times before each backward step to get accumulated prediction metrics…
But for simplicity I’ll ask about this toy problem: is it possible to change the code below so that mean_output
retains valid gradients, but the gradient memory footprint doesn’t grow with additional samples?
```python
batch = next(iter(train_loader))
batch_inputs = batch['x'].cuda()
mean_output = 0
n_samples = 10000  # I want gradient tape memory to remain constant with more samples
for i in range(n_samples):
    # each call builds a new autograd graph that mean_output keeps alive until backward()
    mean_output += stochastic_model(batch_inputs) / n_samples
```
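If the quantity you eventually backpropagate is linear in the samples (for example, the average of per-sample losses), the gradient of the average is the average of the per-sample gradients. In that special case you can call `backward()` inside the loop: PyTorch accumulates into each parameter's `.grad` and frees that sample's graph immediately, while `out.detach()` keeps the value without the graph. A minimal sketch, assuming a per-sample `loss_fn`; `backward_mean_of_losses` is a hypothetical helper and the dropout model is only a stand-in for `stochastic_model`:

```python
import torch
import torch.nn as nn


def backward_mean_of_losses(model, x, loss_fn, n_samples):
    """Backprop (1/n) * sum_i loss_fn(sample_i), keeping only one sample graph alive at a time."""
    mean_output = 0.0
    for _ in range(n_samples):
        out = model(x)                                        # stochastic forward, builds one graph
        (loss_fn(out) / n_samples).backward()                 # adds this share to p.grad, frees graph
        mean_output = mean_output + out.detach() / n_samples  # keep the value, drop the graph
    return mean_output


# Stand-in stochastic model (assumption): dropout in training mode makes each pass random.
model = nn.Sequential(nn.Linear(4, 3), nn.Dropout(p=0.5))
x = torch.randn(5, 4)
mean_output = backward_mean_of_losses(model, x, lambda o: (o ** 2).mean(), n_samples=100)
```

This does not cover a loss that is a nonlinear function of `mean_output` itself, because there the gradient of each sample depends on the final mean.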
P.S. A key point is that accumulating the samples should not store their intermediate tensors.
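For a loss that is an arbitrary function of `mean_output`, one possible approach is a two-pass scheme, sketched here under the assumption that all of the model's randomness comes from the PyTorch RNG, so re-seeding reproduces the same reparameterization noise. Pass 1 computes `mean_output` under `torch.no_grad()`; then `g`, the gradient of the loss with respect to `mean_output`, is taken on a tiny graph; pass 2 replays each sample with the same seed and calls `output.backward(g / n_samples)`. By the chain rule the accumulated parameter gradients equal those from backpropagating through the full graph, but only one sample's graph exists at a time, at the cost of two forward passes per sample. `StochasticLinear`, `sample_output`, `mean_output_backward` and `seed` are hypothetical names, and the `.cuda()` call is omitted so the sketch runs on CPU.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class StochasticLinear(nn.Module):
    """Hypothetical toy layer: weight sampled with the reparameterization trick."""

    def __init__(self, d_in: int, d_out: int):
        super().__init__()
        self.mu = nn.Parameter(0.1 * torch.randn(d_out, d_in))
        self.rho = nn.Parameter(torch.full((d_out, d_in), -3.0))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        eps = torch.randn_like(self.mu)            # noise drawn from the torch RNG
        w = self.mu + F.softplus(self.rho) * eps   # reparameterized weight sample
        return x @ w.t()


def sample_output(model, x, seed):
    """One stochastic forward pass whose noise is fixed by `seed`."""
    with torch.random.fork_rng():  # restores the global RNG state on exit
        torch.manual_seed(seed)
        return model(x)


def mean_output_backward(model, x, loss_fn, n_samples, seed=0):
    """Backprop loss_fn(mean_output) with graph memory independent of n_samples."""
    # Pass 1: compute mean_output without recording any graph.
    with torch.no_grad():
        mean_output = 0.0
        for i in range(n_samples):
            mean_output = mean_output + sample_output(model, x, seed + i) / n_samples
    # Gradient of the loss w.r.t. mean_output only (a tiny graph).
    mean_output = mean_output.requires_grad_(True)
    loss = loss_fn(mean_output)
    (g,) = torch.autograd.grad(loss, mean_output)
    # Pass 2: replay each sample with the same noise and backprop g / n_samples.
    # Each backward() frees that sample's graph; p.grad accumulates the sum.
    for i in range(n_samples):
        sample_output(model, x, seed + i).backward(g / n_samples)
    return loss.detach(), mean_output.detach()


model = StochasticLinear(4, 3)
x, y = torch.randn(8, 4), torch.randn(8, 3)
loss, mean_output = mean_output_backward(model, x, lambda m: F.mse_loss(m, y), n_samples=1000)
```

After this call the parameter gradients are populated, so `optimizer.step()` can follow as usual. If the model draws randomness from anywhere other than the PyTorch RNG, the second pass will not replay the same samples and the gradients will no longer match the first pass exactly.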