[Solved] Reverse gradients in backward pass

Hi Alain,

With the help of Marcin’s awesome solution (thanks!) I have been able to reproduce the results from Bousmalis et al. (2016). An MNIST-trained classifier gets 72% accuracy on USPS (77% in Bousmalis), and the DANN with Marcin’s reversal layer gets 86% on USPS (85% in Bousmalis). I had to adjust Marcin’s code a bit to make it work, possibly because I am using PyTorch version 0.1.12_2. This did the trick:

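```python
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function

class GradReverse(Function):
    # Old-style Function (PyTorch 0.1.x): identity in the forward pass,
    # gradient multiplied by -lambd in the backward pass.
    # lambd is a module-level global that the training loop updates every step.
    def forward(self, x):
        return x.view_as(x)

    def backward(self, grad_output):
        return grad_output * -lambd

def grad_reverse(x):
    return GradReverse()(x)

class domain_classifier(nn.Module):
    def __init__(self):
        super(domain_classifier, self).__init__()
        self.fc1 = nn.Linear(1200, 100)  # 1200 = size of the flattened feature-extractor output
        self.fc2 = nn.Linear(100, 1)
        self.drop = nn.Dropout(0.25)     # plain dropout; Dropout2d is meant for 4-D feature maps

    def forward(self, x):
        x = grad_reverse(x)              # reverse gradients flowing back into the feature extractor
        x = F.leaky_relu(self.drop(self.fc1(x)))
        x = self.fc2(x)
        return F.sigmoid(x)              # domain probability in [0, 1]
```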
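Newer PyTorch versions expect autograd Functions written with static `forward`/`backward` methods and a `ctx` argument, and recent releases raise an error for old-style instances like the one above. A minimal sketch of the same reversal layer in that style, passing `lambd` explicitly instead of reading a global (the `lambd` argument and the small gradient check are my additions, not part of Marcin’s code):

```python
import torch
from torch.autograd import Function


class GradReverse(Function):
    """Identity in the forward pass; multiplies the gradient by -lambd in the backward pass."""

    @staticmethod
    def forward(ctx, x, lambd):
        ctx.lambd = lambd       # keep the scaling factor for the backward pass
        return x.view_as(x)     # identity, returned as a new tensor object

    @staticmethod
    def backward(ctx, grad_output):
        # Reverse and scale the gradient going back to the feature extractor.
        # lambd is a plain float, so it receives no gradient: return None for it.
        return grad_output.neg() * ctx.lambd, None


def grad_reverse(x, lambd=1.0):
    return GradReverse.apply(x, lambd)


if __name__ == "__main__":
    x = torch.ones(3, requires_grad=True)
    grad_reverse(x, lambd=0.5).sum().backward()
    print(x.grad)  # expected: every entry equals -0.5
```

Inside `domain_classifier.forward` this becomes `x = grad_reverse(x, lambd)`, with `lambd` handed in from the training loop rather than read from module scope.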

Training follows the usual PyTorch pattern. On top of that, I implemented the learning-rate and lambda schedules proposed by Ganin et al. (2016). Here’s the code:

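```python
import numpy as np
import torch
from torch.autograd import Variable  # needed on 0.1.x; a no-op wrapper on newer versions

# Assumed to be defined elsewhere: f_ext, c_clf, d_clf, the three optimizers,
# c_crit, d_crit, batch_gen, the data arrays, batch_size, num_epochs and
# num_steps (total number of training iterations).

# domain labels for x = cat([xs, xt]), shaped like d_out
# (assumed convention: 0 = source, 1 = target)
yd = Variable(torch.cat([torch.zeros(batch_size // 2),
                         torch.ones(batch_size // 2)]).view(-1, 1))

j = 0  # global step counter used by the schedules
for i in range(num_epochs):
    source_gen = batch_gen(source_batches, source_idx, Xs_train, ys_train)
    target_gen = batch_gen(target_batches, target_idx, Xt_train, None)

    # iterate over batches
    for (xs, ys) in source_gen:

        # update lambda and learning rate as suggested in the paper
        p = float(j) / num_steps
        lambd = 2. / (1. + np.exp(-10. * p)) - 1  # module-level global read by GradReverse.backward
        lr = 0.01 / (1. + 10 * p)**0.75
        for opt in (d_optimizer, c_optimizer, f_optimizer):
            for group in opt.param_groups:
                group['lr'] = lr  # assigning opt.lr would have no effect
        j += 1

        # skip incomplete source batches, then get the next target batch
        if len(xs) != batch_size // 2:
            continue
        xt = next(target_gen)

        # concatenate source and target batch
        x = torch.cat([xs, xt], 0)

        # 1) train feature_extractor and class_classifier on source batch
        # reset gradients
        f_ext.zero_grad()
        c_clf.zero_grad()

        # calculate class_classifier predictions on batch xs
        c_out = c_clf(f_ext(xs).view(batch_size // 2, -1))

        # optimize feature_extractor and class_classifier on output
        # (each loss comes from its own forward pass, so the graph need not be retained)
        f_c_loss = c_crit(c_out, ys.float())
        f_c_loss.backward()
        c_optimizer.step()
        f_optimizer.step()

        # 2) train feature_extractor and domain_classifier on full batch x
        # reset gradients
        f_ext.zero_grad()
        d_clf.zero_grad()

        # calculate domain_classifier predictions on batch x
        d_out = d_clf(f_ext(x).view(batch_size, -1))

        # optimize feature_extractor and domain_classifier with output
        f_d_loss = d_crit(d_out, yd.float())
        f_d_loss.backward()
        d_optimizer.step()
        f_optimizer.step()
```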
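The two schedules from Ganin et al. (2016) used in the loop: with training progress p going from 0 to 1, λ_p = 2 / (1 + exp(−10p)) − 1 and μ_p = 0.01 / (1 + 10p)^0.75. Instead of overwriting the learning rate by hand, the decay can also be attached to an optimizer with `torch.optim.lr_scheduler.LambdaLR`, which sets the learning rate to the initial value times the factor returned for the current step. A minimal sketch; `num_steps`, the toy model and the SGD settings are placeholders, not values from this thread:

```python
import math

import torch
from torch import nn, optim


def dann_lambda(p, gamma=10.0):
    """Adaptation factor: 0 at p = 0, rising smoothly towards 1 as p approaches 1."""
    return 2.0 / (1.0 + math.exp(-gamma * p)) - 1.0


def dann_lr_factor(p, alpha=10.0, beta=0.75):
    """Learning-rate multiplier: 1 at p = 0, decaying as training progresses."""
    return 1.0 / (1.0 + alpha * p) ** beta


num_steps = 1000               # placeholder: total number of training iterations
model = nn.Linear(10, 1)       # placeholder model
opt = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

# LambdaLR sets lr = initial_lr * factor(step), where step counts sched.step() calls.
sched = optim.lr_scheduler.LambdaLR(opt, lr_lambda=lambda step: dann_lr_factor(step / num_steps))

for step in range(num_steps):
    p = step / num_steps
    lambd = dann_lambda(p)     # pass this to the gradient reversal layer
    # ... forward pass and loss.backward() would go here ...
    opt.step()
    sched.step()               # advance the schedule after the optimizer step
```

With one scheduler per optimizer, the manual `param_groups` loop is no longer needed; only `lambd` still has to be computed each step.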

Thanks again, Marcin, for your solution.
And Alain, I hope this helps you build the model.

Daniel
