I’m implementing Adversarial Training, which consists of two fairly simple steps:
- Jointly minimize the loss function F(x, θ) + F(x + perturbation, θ).
- The perturbation is the derivative of F(x, θ) w.r.t. x, scaled by epsilon. (A toy sketch of both steps follows this list.)
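To be concrete about what I mean, here is a self-contained toy sketch of those two steps. This is not my actual code: `model`, `emb`, `y`, and `epsilon` are placeholders for a network that consumes embeddings, the embedded input, the labels, and the perturbation scale.

```python
import torch
import torch.nn.functional as F

def adversarial_training_loss(model, emb, y, epsilon=1.0):
    # Treat the embedded input as a leaf so we can take d(loss)/d(x).
    emb = emb.detach().requires_grad_(True)
    loss = F.cross_entropy(model(emb), y)                  # F(x, θ)

    # The perturbation: gradient of F(x, θ) w.r.t. x, normalized and scaled.
    # retain_graph=True keeps the first graph alive so `loss` can still be
    # backpropagated as part of the joint objective below.
    grad, = torch.autograd.grad(loss, emb, retain_graph=True)
    perturb = epsilon * F.normalize(grad.detach(), p=2, dim=-1)

    # Jointly minimize F(x, θ) + F(x + perturb, θ).
    loss_adv = F.cross_entropy(model(emb + perturb), y)    # F(x + perturb, θ)
    return loss + loss_adv
```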
Can someone verify my implementation of these two steps? Unfortunately, I don’t have anyone but this community to review my PyTorch code. I’m supposed to get better performance with adversarial training, but I’m not seeing it.
- The input x is the word embedding. I’m not updating x; I made x a leaf node only so that I can get its gradient (and use it as the perturbation). A toy example of this setup is right below.
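Just to illustrate what I mean by that, with a made-up vocabulary size and embedding dimension (not my real network):

```python
import torch
import torch.nn as nn

embedding = nn.Embedding(num_embeddings=100, embedding_dim=8)  # toy vocab/dim
token_ids = torch.tensor([[1, 2, 3]])

# Detach the embedded input so x becomes a leaf with its own .grad.
x = embedding(token_ids).detach().requires_grad_(True)

loss = x.sum()       # stand-in for the rest of the network + loss
loss.backward()
print(x.grad.shape)  # torch.Size([1, 3, 8]); this gradient would become the perturbation
```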
Below is my implementation.
```python
import torch.nn.functional as F


class Model():
    def update(self, batch):
        self.network.train()

        # Forward propagate. These are the predicted labels for the answer
        # span (NLP question answering span).
        start, end = self.network(batch)

        # F(x, θ)
        loss = F.cross_entropy(start, y) + F.cross_entropy(end, y)

        # F(x + perturb, θ)
        loss_adv = self.adversarial_loss(batch, loss, y)
        loss_total = loss + loss_adv

        self.optimizer.zero_grad()
        # Jointly optimize F(x, θ) + F(x + perturb, θ).
        loss_total.backward(retain_graph=False)
        self.optimizer.step()

    def adversarial_loss(self, batch, loss, y):
        """
        Input:
            - batch: batch that is shared with the update method.
            - loss : F(x, θ)
            - y    : y label (answer span)
        Output:
            - adversarial loss F(x + perturb, θ). Scalar tensor.
        """
        self.optimizer.zero_grad()        # remove accumulated grads from the previous batch
        loss.backward(retain_graph=True)  # backprop; graph retained for loss_total

        # grad = d(loss)/d(x). Shape: [# of vocab, word emb dimension]
        grad = self.network.lexicon_encoder.embedding.weight.grad.data
        grad.detach_()  # **QUESTION: should it be loss.detach_()?**

        perturb = F.normalize(grad, p=2, dim=1) * 5.0  # 5 is the norm of the perturbation. Hyperparam.
        adv_embedding = self.network.lexicon_encoder.embedding.weight + perturb

        # Cache the original embedding's weight (to revert it back later).
        original_emb = self.network.lexicon_encoder.embedding.weight.data
        self.network.lexicon_encoder.embedding.weight.data = adv_embedding

        # Forward propagate for F(x + perturb, θ).
        adv_start, adv_end, adv_pred = self.network(batch)

        # Revert x back to the original state before the perturbation.
        self.network.lexicon_encoder.embedding.weight.data = original_emb

        # Switch off the embedding training.
        self.network.lexicon_encoder.embedding.training = False

        return F.cross_entropy(adv_start, y) + F.cross_entropy(adv_end, y)
```
I hope I can get some feedback on this. Thanks for reading!