Hello,
I’m implementing adversarial training, which consists of two fairly simple steps:
1) Jointly minimize the loss F(x, θ) + F(x+perturbation, θ).
2) The perturbation is the derivative of F(x, θ) w.r.t. x, scaled by epsilon.
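Written out, the two steps correspond roughly to the objective below. This is a sketch of the common fast-gradient formulation, assuming the gradient is L2-normalized before scaling (as the `F.normalize` call in the code does) and that the perturbation is treated as a constant during the parameter update; ε plays the role of the 5.0 in the code.

```latex
g = \nabla_x F(x, \theta), \qquad
r_{\text{adv}} = \epsilon \, \frac{g}{\lVert g \rVert_2}

\min_{\theta} \; F(x, \theta) + F(x + r_{\text{adv}}, \theta)
```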
Can someone verify my implementation of 1) and 2)? Unfortunately, this community is the only place I can get my PyTorch code reviewed. Adversarial training is supposed to improve performance, but in my case it doesn’t.
Additional info:
- The input `x` is the word embedding. I’m not updating `x`; I made `x` a leaf node only so that I could get its gradient and use it as the perturbation.
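If `x` is meant to be the embeddings looked up for the tokens in the batch (rather than the whole embedding table), one way to get that gradient without making anything a leaf is to call `torch.autograd.grad` on the embedding output. The sketch below assumes a hypothetical `ToySpanModel` whose `forward_from_embeddings` method accepts embeddings directly; the real network would need an equivalent entry point. Normalization is per token, mirroring `dim=1` in the code; other formulations normalize over the whole sequence instead.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ToySpanModel(nn.Module):
    """Hypothetical stand-in for the QA network: token embeddings followed by a
    linear layer that scores each position as answer start or end."""

    def __init__(self, vocab_size=50, dim=8):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, dim)
        self.scorer = nn.Linear(dim, 2)

    def forward_from_embeddings(self, emb):
        # emb: [batch, seq_len, dim] -> start and end logits, each [batch, seq_len]
        logits = self.scorer(emb)
        return logits[..., 0], logits[..., 1]


def span_loss(start, end, y_start, y_end):
    return F.cross_entropy(start, y_start) + F.cross_entropy(end, y_end)


def adversarial_step(model, optimizer, ids, y_start, y_end, epsilon=5.0):
    emb = model.embedding(ids)  # x: embeddings of the tokens in this batch
    loss = span_loss(*model.forward_from_embeddings(emb), y_start, y_end)

    # Gradient of F(x, θ) w.r.t. x only; the parameters' .grad fields are not touched.
    (grad,) = torch.autograd.grad(loss, emb, retain_graph=True)

    # One perturbation per token with L2 norm epsilon, detached so it is a constant.
    perturb = (epsilon * F.normalize(grad, p=2, dim=-1)).detach()
    loss_adv = span_loss(*model.forward_from_embeddings(emb + perturb), y_start, y_end)

    # Jointly minimize F(x, θ) + F(x + perturb, θ).
    optimizer.zero_grad()
    (loss + loss_adv).backward()
    optimizer.step()
    return loss.item(), loss_adv.item(), perturb


# Usage with random data: a batch of 4 sequences of 10 tokens.
torch.manual_seed(0)
model = ToySpanModel()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
ids = torch.randint(0, 50, (4, 10))
y_start = torch.randint(0, 10, (4,))
y_end = torch.randint(0, 10, (4,))
loss, loss_adv, perturb = adversarial_step(model, optimizer, ids, y_start, y_end)
```

Because `torch.autograd.grad` returns the gradient instead of accumulating it into `.grad`, the first backward pass leaves no parameter gradients behind that would need clearing before the joint update.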
Below is my implementation.
```python
import torch.nn.functional as F


class Model:
    def update(self, batch):
        self.network.train()
        # Forward pass: predicted start/end logits of the answer span (extractive QA).
        start, end = self.network(batch)
        # F(x, θ)
        loss = F.cross_entropy(start, y[0]) + F.cross_entropy(end, y[1])
        # F(x+perturb, θ)
        loss_adv = self.adversarial_loss(batch, loss, y)
        loss_total = loss + loss_adv
        self.optimizer.zero_grad()
        # Jointly optimize F(x, θ) + F(x+perturb, θ).
        loss_total.backward(retain_graph=False)
        self.optimizer.step()

    def adversarial_loss(self, batch, loss, y):
        """
        Input:
        - batch: batch that is shared with the update method.
        - loss : F(x, θ)
        - y    : y label (answer span)
        Output:
        - adversarial loss F(x+perturb, θ). Scalar tensor.
        """
        self.optimizer.zero_grad()  # Remove accumulated grads from the previous batch.
        loss.backward(retain_graph=True)  # Backprop. I retained the graph for loss_total.
        # grad = d(loss)/d(x). Shape: [vocab size, word emb dimension]
        grad = self.network.lexicon_encoder.embedding.weight.grad.data
        grad.detach_()  # QUESTION: should it be loss.detach_()?
        perturb = F.normalize(grad, p=2, dim=1) * 5.0  # 5 is the norm of the perturbation (hyperparameter).
        adv_embedding = self.network.lexicon_encoder.embedding.weight + perturb
        # Cache the original embedding weights (to revert them later).
        original_emb = self.network.lexicon_encoder.embedding.weight.data
        self.network.lexicon_encoder.embedding.weight.data = adv_embedding
        # Forward pass for F(x+perturb, θ).
        adv_start, adv_end, adv_pred = self.network(batch)
        # Revert x to its original state before the perturbation.
        self.network.lexicon_encoder.embedding.weight.data = original_emb
        # Switch off embedding training.
        self.network.lexicon_encoder.embedding.training = False
        return F.cross_entropy(adv_start, y[0]) + F.cross_entropy(adv_end, y[1])
```
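To see what the table-level perturbation in `adversarial_loss` looks like, the small sketch below applies the same `F.normalize(..., dim=1) * 5.0` to `embedding.weight.grad` for a toy embedding table and a stand-in loss (both hypothetical, not the real network). Rows for words that do not occur in the batch have zero gradient and stay zero, while each occurring word gets a perturbation of norm 5.0. A word that occurs several times receives one shared perturbation built from its summed gradient, rather than one per position.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical toy setup: a 6-word vocabulary with 4-dim embeddings.
embedding = nn.Embedding(6, 4)
ids = torch.tensor([[1, 3, 3]])  # only words 1 and 3 occur; word 3 occurs twice

loss = embedding(ids).pow(2).sum()  # stand-in for F(x, θ)
loss.backward()

grad = embedding.weight.grad  # gradient w.r.t. the whole table: [vocab size, emb dim]
perturb = F.normalize(grad, p=2, dim=1) * 5.0  # same call as in adversarial_loss

row_norms = perturb.norm(dim=1)
# Rows 1 and 3 have norm 5.0; every other row has zero gradient and stays zero.
# Both occurrences of word 3 share one perturbation built from their summed gradient.
```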
I hope to get some feedback on this. Thanks for reading!