Efficient Backprop for Virtual Adversarial Training


I’m trying to implement a version of the Virtual Adversarial Training model from this paper: https://arxiv.org/pdf/1704.03976.pdf in PyTorch. I have a working version, but I think it’s not as efficient as it could be. The problem is the following (in pseudo-code to make things clearer):

In order to do the power iteration step to approximate the adversarial direction as in the paper, we need to call forward on the model twice:

logits_1 = model(inputs)
logits_2 = model(inputs+r_random)

compute the cross-entropy, and backprop to get the derivative with respect to r_random (note that r_random must be created with requires_grad=True):

xentropy = cross_entropy(logits_1.detach(), logits_2)
xentropy.backward()


r_adversarial = l2_normalize(r_random.grad.detach().clone())
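The `l2_normalize` helper isn’t shown above; a minimal per-sample version (an assumption about what it does, following the paper’s normalization of the perturbation) might look like:

```python
import torch

def l2_normalize(d, eps=1e-8):
    # Normalize each sample's perturbation to unit L2 norm.
    # d has shape (batch, ...); flatten all non-batch dims for the norm.
    d_flat = d.view(d.size(0), -1)
    norm = d_flat.norm(p=2, dim=1, keepdim=True) + eps
    return (d_flat / norm).view_as(d)
```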

At this point, we don’t want any of the accumulated gradients to be used in the update; we just wanted to find r_adversarial, so we zero the gradients:

model.zero_grad()
Then to get the total loss, we have to call forward two more times:

logits_3 = model(inputs)
logits_4 = model(inputs+r_adversarial)

adversarial_loss = cross_entropy(logits_3.detach(), logits_4)

nll_loss = cross_entropy(logits_3, labels)

loss = adversarial_loss+nll_loss

My question is: in the first calculation we already computed model(inputs). Is there any way to save the forward part of that graph, without accumulating any gradients when we calculate r_adversarial, so that we don’t have to compute model(inputs) a second time? That should save a fair amount of time, since we would only need three forward passes instead of four.



@Shawn_Henry Any luck on that front?

Would passing retain_graph not work?
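For reference, here is a minimal illustration of what retain_graph does: it keeps the saved activations alive so you can call backward a second time through the same graph. Note that it does not by itself prevent gradient accumulation, which is the concern in this thread:

```python
import torch

w = torch.tensor([1.0], requires_grad=True)
y = (w * 2).sum()

y.backward(retain_graph=True)  # graph kept alive for another backward
y.backward()                   # second backward through the same graph

print(w.grad.item())  # 4.0 (gradients accumulate: 2 + 2)
```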


I also implemented VAT and reduced the number of forward passes to three, but I soon noticed that this can’t be done without gradient accumulation.
I would like to hear whether it is feasible or not.

What about this?

adversarial_loss = cross_entropy(logits_1.detach(), logits_4)
nll_loss = cross_entropy(logits_1, labels)

With that solution, the only place where logits_1 is used without .detach() is in the calculation of nll_loss, so I think everything will work out exactly as desired.
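Putting the pieces together, a single VAT step with three forward passes might look like the sketch below. This is only an illustration: `vat_step`, `soft_cross_entropy`, and the hyperparameters `xi` and `epsilon` (the perturbation scales from the paper) are names I’m introducing here, and the distance is the plain cross-entropy used in the snippets above rather than the paper’s KL divergence:

```python
import torch
import torch.nn.functional as F

def l2_normalize(d, eps=1e-8):
    # Normalize each sample's perturbation to unit L2 norm.
    d_flat = d.view(d.size(0), -1)
    norm = d_flat.norm(p=2, dim=1, keepdim=True) + eps
    return (d_flat / norm).view_as(d)

def soft_cross_entropy(target_logits, logits):
    # Cross-entropy between a (detached) target distribution and logits.
    target = F.softmax(target_logits, dim=1)
    return -(target * F.log_softmax(logits, dim=1)).sum(dim=1).mean()

def vat_step(model, optimizer, inputs, labels, xi=1e-6, epsilon=1.0):
    # Forward pass 1: clean logits, reused for both losses.
    logits_1 = model(inputs)

    # Power-iteration step: find the adversarial direction.
    r_random = xi * l2_normalize(torch.randn_like(inputs))
    r_random.requires_grad_()
    logits_2 = model(inputs + r_random)            # forward pass 2
    dist = soft_cross_entropy(logits_1.detach(), logits_2)
    dist.backward()  # logits_1 is detached, so its graph survives for later

    r_adv = epsilon * l2_normalize(r_random.grad.detach())
    optimizer.zero_grad()  # discard grads accumulated by the backward above

    # Forward pass 3: adversarial logits.
    logits_3 = model(inputs + r_adv)
    adversarial_loss = soft_cross_entropy(logits_1.detach(), logits_3)
    nll_loss = F.cross_entropy(logits_1, labels)
    loss = adversarial_loss + nll_loss

    loss.backward()
    optimizer.step()
    return loss.item()
```

Because logits_1 is detached inside the distance term, the first backward never traverses (or frees) the graph of the clean forward pass, so it remains available for nll_loss.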

The fact that you zero the grads earlier does not affect the final grads because the final grads aren’t calculated until you run loss.backward().
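A quick way to convince yourself of this, with a toy parameter standing in for the model:

```python
import torch

w = torch.tensor([2.0], requires_grad=True)

# First backward (analogous to the power-iteration step).
(w * 3).sum().backward()
assert w.grad.item() == 3.0

# Zero the accumulated gradient, as in the VAT code.
w.grad.zero_()

# The final backward starts from zero, so the earlier pass is irrelevant.
(w * 5).sum().backward()
print(w.grad.item())  # 5.0, not 8.0
```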
