# Efficient Backprop for Virtual Adversarial Training

Hi,

I’m trying to implement a version of the Virtual Adversarial Training model from this paper: https://arxiv.org/pdf/1704.03976.pdf in PyTorch. I have a successful version, but I think it’s not as efficient as it could be. The problem is the following (in pseudo-code to make things clearer):

In order to do the power iteration step to approximate the adversarial direction as in the paper, we need to call forward on the model twice:

```
logits_1 = model(inputs)
logits_2 = model(inputs + r_random)
```

Then compute the cross-entropy and backprop to get the derivative with respect to `r_random`:

```
xentropy = cross_entropy(logits_1.detach(), logits_2)
xentropy.backward()
```

At this point, we don’t want any of the accumulated gradients to be used in the update (we only wanted to find `r_adversarial`), so we zero the gradients:

```
model.zero_grad()
```

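Then, to get the total loss, we have to call forward two more times:

```
logits_3 = model(inputs)
logits_4 = model(inputs + r_adversarial)

nll_loss = cross_entropy(logits_3, labels)
adversarial_loss = cross_entropy(logits_3.detach(), logits_4)
loss = nll_loss + adversarial_loss
loss.backward()
```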

My question is: in the first calculation we already computed `model(inputs)`. Is there any way to keep the forward part of that graph, without accumulating any gradients when we calculate `r_adversarial`, so that we don’t have to compute `model(inputs)` a second time? That should save a fair amount of time, since we would need only three forward passes instead of four.
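On the "without accumulating any gradients" part, a minimal sketch of the direction step that never writes to the parameters' `.grad`, assuming a plain classifier `model` with no dropout or batch norm. It swaps the `backward()` + zero-the-grads pair for `torch.autograd.grad`, which returns the gradient with respect to `r_random` only. The helper names (`soft_cross_entropy`, `unit_rows`, `adversarial_direction`) and the `xi`/`eps` defaults are hypothetical, and the cross-entropy here is taken between the two softmax distributions rather than with `F.cross_entropy`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def soft_cross_entropy(target_logits, logits):
    # Hypothetical helper: H(softmax(target_logits), softmax(logits)), batch mean.
    p = F.softmax(target_logits, dim=1)
    return -(p * F.log_softmax(logits, dim=1)).sum(dim=1).mean()


def unit_rows(d):
    # Scale each example's perturbation to unit L2 norm.
    return F.normalize(d.flatten(start_dim=1), dim=1).view_as(d)


def adversarial_direction(model, inputs, logits_1, xi=1e-6, eps=1.0):
    # One power-iteration step; the xi and eps defaults are assumptions.
    r_random = (xi * unit_rows(torch.randn_like(inputs))).requires_grad_()
    logits_2 = model(inputs + r_random)
    xentropy = soft_cross_entropy(logits_1.detach(), logits_2)
    # Gradient w.r.t. r_random only; the parameters' .grad fields stay untouched.
    (grad_r,) = torch.autograd.grad(xentropy, r_random)
    return eps * unit_rows(grad_r)
```

Usage: `r_adversarial = adversarial_direction(model, inputs, logits_1)`. Because `logits_1` only enters through `.detach()`, its graph is still available for the supervised loss afterwards.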

Thanks!
Shawn


@Shawn_Henry Any luck on that front?

Would passing `retain_graph` not work?
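A small check of when `retain_graph` would actually be needed, assuming a toy `nn.Sequential` classifier and a hypothetical `soft_xent` helper for the cross-entropy between the two predicted distributions. If `logits_1` enters `xentropy` only through `.detach()`, the first `backward()` never walks the clean graph, so that graph can still be backpropagated later without `retain_graph=True`. Without the `.detach()`, the second backward fails because the clean graph's saved tensors were freed.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(4, 8), nn.Tanh(), nn.Linear(8, 3))
inputs = torch.randn(5, 4)
labels = torch.randint(0, 3, (5,))


def soft_xent(target_logits, logits):
    # Hypothetical helper: cross-entropy between two predicted distributions.
    p = F.softmax(target_logits, dim=1)
    return -(p * F.log_softmax(logits, dim=1)).sum(dim=1).mean()


def clean_graph_survives(detach_target):
    r_random = (1e-2 * torch.randn_like(inputs)).requires_grad_()
    logits_1 = model(inputs)             # clean graph, wanted later for nll_loss
    logits_2 = model(inputs + r_random)  # perturbed graph, only for the direction
    target = logits_1.detach() if detach_target else logits_1
    soft_xent(target, logits_2).backward()  # note: no retain_graph
    model.zero_grad()
    try:
        # Reuse the clean graph for the supervised loss.
        F.cross_entropy(logits_1, labels).backward()
        return True
    except RuntimeError:
        # Raised when backward reaches buffers freed by the first backward.
        return False
```

So for the pattern in the question, the `.detach()` already keeps the clean graph alive; `retain_graph=True` would only matter if `logits_1` were not detached.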


I also implemented VAT with the number of forward passes reduced to three, but I soon noticed that I couldn’t make it work without gradient accumulation.
I would like to hear whether that is feasible or not.

```
adversarial_loss = cross_entropy(logits_1.detach(), logits_4)
```
With that solution the only place where `logits_1` is used without `.detach()` is in the calculation of `nll_loss`, so I think everything will work out exactly as desired.
The fact that you zero the grads earlier does not affect the final grads because the final grads aren’t calculated until you run `loss.backward()`.
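A minimal, self-contained sketch of the three-pass step described in this answer, assuming a deterministic toy model (no dropout or batch norm, so recomputing `model(inputs)` gives identical logits). The `soft_xent` helper, the `vat_step` wrapper with its `reuse_clean_logits` flag, and the `eps` default are hypothetical; `reuse_clean_logits=False` reproduces the four-pass version from the question so the two can be compared.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def soft_xent(target_logits, logits):
    # Hypothetical helper: cross-entropy between two predicted distributions.
    p = F.softmax(target_logits, dim=1)
    return -(p * F.log_softmax(logits, dim=1)).sum(dim=1).mean()


def vat_step(model, inputs, labels, r_random, eps=1.0, reuse_clean_logits=True):
    logits_1 = model(inputs)                         # forward 1: clean inputs
    r_random = r_random.detach().clone().requires_grad_()
    logits_2 = model(inputs + r_random)              # forward 2: random perturbation
    xentropy = soft_xent(logits_1.detach(), logits_2)
    xentropy.backward()                              # fills r_random.grad and param grads
    model.zero_grad()                                # discard the param grads; graph of logits_1 is untouched
    r_adversarial = eps * F.normalize(r_random.grad.flatten(1), dim=1).view_as(inputs)
    # Three-pass variant reuses logits_1; four-pass variant recomputes it.
    logits_3 = logits_1 if reuse_clean_logits else model(inputs)
    logits_4 = model(inputs + r_adversarial)         # last forward: adversarial inputs
    nll_loss = F.cross_entropy(logits_3, labels)
    adversarial_loss = soft_xent(logits_3.detach(), logits_4)
    loss = nll_loss + adversarial_loss
    loss.backward()
    return loss.detach(), [p.grad.clone() for p in model.parameters()]
```

With a deterministic model the two variants give matching losses and parameter gradients, which is the point made above: zeroing the grads after the direction step does not touch the graph of `logits_1`. With dropout or batch norm active, the reused and recomputed clean logits differ, so the two variants are no longer numerically identical.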