Adversarial training implementation in PyTorch

Hello,
I’m implementing adversarial training, which consists of two fairly simple steps:

  1. Jointly minimize the loss function F(x, θ) + F(x + r_adv, θ).
  2. The perturbation r_adv is the gradient of F(x, θ) w.r.t. x, rescaled to have norm ε (a standalone toy sketch follows this list).
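
To make sure the math is clear, here is a standalone toy sketch of steps 1 and 2 as I understand them (toy tensors and a dummy loss, not my actual model; epsilon corresponds to the 5.0 in the code further down):

import torch
import torch.nn.functional as F

epsilon = 5.0                                    # norm of the perturbation (hyperparameter)
x = torch.randn(10, 300, requires_grad=True)     # toy stand-in for the word embeddings (leaf tensor)
theta = torch.randn(300, 2, requires_grad=True)  # toy stand-in for the model parameters θ

def loss_fn(inp):                                # stands in for F(·, θ)
    return (inp @ theta).sum()

loss = loss_fn(x)                                # F(x, θ)
loss.backward()                                  # autograd fills x.grad with d(loss)/dx
g = x.grad.detach()
r_adv = epsilon * F.normalize(g, p=2, dim=1)     # r_adv = ε * g / ||g||_2 (row-wise)
loss_total = loss_fn(x) + loss_fn(x + r_adv)     # F(x, θ) + F(x + r_adv, θ)
loss_total.backward()                            # this joint objective is what updates θ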

Can someone verify my implementation of 1) and 2)? Unfortunately, I don’t have anyone but this community to review my PyTorch code. Adversarial training is supposed to give better performance, but I’m not seeing any improvement.

Additional info:

  • the input x is the word embedding. I’m not updating x itself; I only made it a leaf node so that I can read its gradient and use it as the perturbation (a short illustration follows).
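
To be concrete about what I mean by “leaf node”: the gradient I read is the one on the embedding weight matrix, which is an nn.Parameter and therefore already a leaf. A minimal illustration (vocab_size and emb_dim are placeholder values):

import torch.nn as nn

vocab_size, emb_dim = 30000, 300          # placeholder sizes, just for illustration
emb = nn.Embedding(vocab_size, emb_dim)   # stands in for self.network.lexicon_encoder.embedding
x = emb.weight                            # leaf nn.Parameter of shape [vocab_size, emb_dim]
print(x.is_leaf, x.requires_grad)         # True True -> x.grad gets filled by loss.backward()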

Below is my implementation.

import torch.nn.functional as F


class Model:
    def update(self, batch):
        self.network.train()
        # forward pass: start/end are the predicted answer-span logits (extractive QA)
        start, end = self.network(batch)
        # gold answer spans (start, end); assumption: they are unpacked from the batch like this
        y = batch['labels']
        # F(x, θ)
        loss = F.cross_entropy(start, y[0]) + F.cross_entropy(end, y[1])
        # F(x+perturb, θ)
        loss_adv = self.adversarial_loss(batch, loss, y)
        loss_total = loss + loss_adv
        self.optimizer.zero_grad() 
        # jointly optimize F(x, θ) + F(x+perturb, θ)
        loss_total.backward(retain_graph=False)        
        self.optimizer.step()       

    def adversarial_loss(self, batch, loss, y):
        """
        Input:
                - batch: the same batch that update() received.
                - loss : F(x, θ), the clean loss (graph still attached).
                - y    : gold answer span labels (start, end).
        Output:
                - adversarial loss F(x+perturb, θ), as a scalar tensor.
        """
        self.optimizer.zero_grad()  # remove grads accumulated from the previous batch
        loss.backward(retain_graph=True)  # backprop; keep the graph so loss_total.backward() in update() can reuse it
        
        # grad = d(loss)/d(x). Shape: [vocab size, word emb dimension]
        grad = self.network.lexicon_encoder.embedding.weight.grad.data
        grad.detach_()  # QUESTION: should this be loss.detach_() instead?
        perturb = F.normalize(grad, p=2, dim=1) * 5.0  # 5.0 = epsilon, the norm of the perturbation (hyperparameter)
        adv_embedding = self.network.lexicon_encoder.embedding.weight + perturb

        # cache the original embedding's weight (to revert it back later)
        original_emb = self.network.lexicon_encoder.embedding.weight.data
        self.network.lexicon_encoder.embedding.weight.data = adv_embedding 

        # forward pass for F(x+perturb, θ)
        adv_start, adv_end = self.network(batch)
        
        # revert x back to the original state before the perturbation.
        self.network.lexicon_encoder.embedding.weight.data = original_emb 
        # switch off the embedding training. 
        self.network.lexicon_encoder.embedding.training = False
        
        return F.cross_entropy(adv_start, y[0]) + F.cross_entropy(adv_end, y[1])
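
For context, update() is called once per batch in an otherwise ordinary training loop, roughly like this (num_epochs and train_loader are placeholders, and Model's __init__, which sets up self.network and self.optimizer, is omitted above):

model = Model()
for epoch in range(num_epochs):     # num_epochs: placeholder
    for batch in train_loader:      # train_loader: placeholder DataLoader yielding QA batches
        model.update(batch)         # one joint clean + adversarial step per batch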

I hope I can get some feedback on this. Thanks for reading!