[Solved] Reverse gradients in backward pass


(Daniel Bartolomé) #1

Hello everyone,

I am working on building a DANN (Ganin et al., 2016) in PyTorch. This model is used for domain adaptation: it forces a classifier to learn only features that are shared between two different domains, so that it generalizes across them. The DANN uses a gradient reversal layer (GRL) to achieve this.

I have seen some suggestions on this forum on how to modify gradients manually. However, I found them difficult to apply in my case, because the gradients are reversed midway through the backward pass: once the backward pass reaches the feature extractor, the gradients go through the GRL (gradient reversal layer) and are negated.

Below is my take on this, though I am not sure if I have used the hook correctly. I would greatly appreciate some suggestions on how to use these hooks effectively!

    # this is the hook to reverse the gradients
    def grad_reverse(grad):
        return grad.clone() * -lambd
    lambd = 1

    # 2) train feature_extractor and domain_classifier on full batch x
    
    # reset gradients
    f_ext.zero_grad()
    d_clf.zero_grad()
    
    # calculate domain_classifier predictions on batch x
    d_out = d_clf(f_ext(x).view(batch_size, -1))
    
    # use normal gradients to optimize domain_classifier
    f_d_loss = d_crit(d_out, yd.float())
    f_d_loss.backward(retain_variables = True)
    d_optimizer.step()
    
    # use reversed gradients to optimize feature_extractor
    d_out.register_hook(grad_reverse)
    f_d_loss = d_crit(d_out, yd.float())
    f_d_loss.backward(retain_variables = True)
    f_optimizer.step()

Thanks,

Daniel


(Marcin Elantkowski) #2

Hi,

Wouldn’t it be most elegant to just have a gradient reversal layer?

class GradReverse(Function):
    def forward(self, x):
        # identity in the forward pass
        return x

    def backward(self, grad_output):
        # negate the gradient in the backward pass
        return -grad_output

def grad_reverse(x):
    return GradReverse()(x)

(Daniel Bartolomé) #3

Indeed, that would be most elegant. The reason I used a hook is that hooks were suggested for manually adjusting gradients. For this problem, though, your approach could be a nicer solution.

I’m new to PyTorch, so defining custom layers is something I have never done. With your snippet defined, could I use it by simply treating it as a normal layer? That is:

class domain_classifier(nn.Module):
    def __init__(self):
        super(domain_classifier, self).__init__()
        self.fc1 = nn.Linear(1200, 100) 
        self.fc2 = nn.Linear(100, 1)
    
    def forward(self, x):
        x = grad_reverse(x)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.sigmoid(x)

(Marcin Elantkowski) #4

I think that should work.

Also, I just realized that Function should be defined differently in newer versions of PyTorch:


class GradReverse(Function):
    @staticmethod
    def forward(ctx, x):
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output.neg()

def grad_reverse(x):
    return GradReverse.apply(x)

The return x.view_as(x) seems to be necessary, because otherwise backward is not called. I guess that, as an optimization, autograd checks whether the Function modified its input to decide whether backward needs to be called.
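
For completeness, if you also want the lambda scaling from the paper with this static-method style, one way (just a sketch; the lambd argument, its default, and the *Scaled names are my own) is to pass the coefficient through apply and stash it on ctx. Note that backward must then return one gradient per forward input, hence the None for lambd:

from torch.autograd import Function

class GradReverseScaled(Function):
    @staticmethod
    def forward(ctx, x, lambd):
        # remember the scaling factor for the backward pass
        ctx.lambd = lambd
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # reverse and scale the gradient; None is the "gradient" for lambd
        return grad_output.neg() * ctx.lambd, None

def grad_reverse_scaled(x, lambd=1.0):
    return GradReverseScaled.apply(x, lambd)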


(alain) #5

Hi Daniel,

Did you manage to get this to work? If so, could you share your training code and your domain classifier class?

I am struggling to make mine work :weary:

thanks in advance
A.


(Daniel Bartolomé) #6

Hi Alain,

With the help of Marcin’s awesome solution (thanks!) I have been able to reproduce results from Bousmalis et al. (2016). I get 72% (77% in Bousmalis) on USPS with an MNIST-trained classifier, and 86% (85% in Bousmalis) on USPS using the DANN with Marcin’s reversal layer. I had to adjust Marcin’s code a bit to make it work. I am using PyTorch version 0.1.12_2, so maybe that has something to do with it. This did the trick:

class GradReverse(Function):
    def forward(self, x):
        return x.view_as(x)

    def backward(self, grad_output):
        # lambd is a global that is updated in the training loop below
        return (grad_output * -lambd)

def grad_reverse(x):
    return GradReverse()(x)

class domain_classifier(nn.Module):
    def __init__(self):
        super(domain_classifier, self).__init__()
        self.fc1 = nn.Linear(1200, 100) 
        self.fc2 = nn.Linear(100, 1)
        self.drop = nn.Dropout2d(0.25)

    def forward(self, x):
        x = grad_reverse(x)
        x = F.leaky_relu(self.drop(self.fc1(x)))
        x = self.fc2(x)
        return F.sigmoid(x)

To train the model, the standard PyTorch training procedure applies. I also implemented the learning rate and lambda schedules proposed in the paper (Ganin et al., 2016). Here’s the code:

for i in range(num_epochs):
    source_gen = batch_gen(source_batches, source_idx, Xs_train, ys_train)
    target_gen = batch_gen(target_batches, target_idx, Xt_train, None)

    # iterate over batches
    for (xs, ys) in source_gen:

        # update lambda and learning rate as suggested in the paper
        # (j is the global step counter and num_steps the total number of
        #  training steps; presumably both are maintained outside this snippet)
        p = float(j) / num_steps
        lambd = 2. / (1. + np.exp(-10. * p)) - 1
        lr = 0.01 / (1. + 10 * p)**0.75
        # assigning optimizer.lr directly has no effect; the learning rate
        # lives in the optimizer's param_groups
        for optimizer in (d_optimizer, c_optimizer, f_optimizer):
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

        # skip incomplete source batches, then fetch the next target batch
        if len(xs) != batch_size // 2:
            continue
        xt = next(target_gen)

        # concatenate source and target batch
        x = torch.cat([xs, xt], 0)

        # 1) train feature_extractor and class_classifier on source batch
        # reset gradients
        f_ext.zero_grad()
        c_clf.zero_grad()

        # calculate class_classifier predictions on batch xs
        c_out = c_clf(f_ext(xs).view(batch_size // 2, -1))

        # optimize feature_extractor and class_classifier on output
        f_c_loss = c_crit(c_out, ys.float())
        f_c_loss.backward(retain_variables=True)
        c_optimizer.step()
        f_optimizer.step()

        # 2) train feature_extractor and domain_classifier on full batch x
        # reset gradients
        f_ext.zero_grad()
        d_clf.zero_grad()

        # calculate domain_classifier predictions on batch x
        d_out = d_clf(f_ext(x).view(batch_size, -1))

        # optimize feature_extractor and domain_classifier with output
        # (yd holds the domain labels; see the sketch below)
        f_d_loss = d_crit(d_out, yd.float())
        f_d_loss.backward(retain_variables=True)
        d_optimizer.step()
        f_optimizer.step()
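
One thing the snippet leaves out is how the domain labels yd are built. Assuming (my reading of the torch.cat call) that the first half of x is the source batch and the second half the target batch, a minimal sketch could look like the following; depending on your criterion you may need to reshape it to match d_out:

import torch
from torch.autograd import Variable

# domain labels for the concatenated batch x: 0 = source, 1 = target
yd = Variable(torch.cat([torch.zeros(batch_size // 2),
                         torch.ones(batch_size // 2)], 0))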

Thanks again, Marcin, for your solution.
And Alain, I hope this helps you build the model.

Daniel


(Marcin Elantkowski) #7

I’m glad to see it worked! I was interested in this model myself; it’s cool to see that you were able to reproduce the results.

Cheers!


(alain) #8

Thanks for the answer. This was indeed very helpful.

A.


(Harikrishna.Vydana) #10

class GradReverse(Function):
    def forward(self, x):
        return x.view_as(x)

    def backward(self, grad_output):
        return (grad_output * -lambd)

def grad_reverse(x):
    return GradReverse()(x)

Running the above snippet gives the error:

class GradReverse(Function):
TypeError: Error when calling the metaclass bases
    module.__init__() takes at most 2 arguments (3 given)

But I only gave a single 2D tensor as input. What are the two inputs it takes, x and lambd?


(Daniel Bartolomé) #11

Hi,

In the implementation described above, lambd is a global variable that is not passed to grad_reverse. If you passed lambd to this function, that would explain the one extra argument.

I did not like using lambd as a global, which is why I added a constructor to the GradReverse class that takes the lambd value.

class GradReverse(Function):
    def __init__(self, lambd):
        self.lambd = lambd

    def forward(self, x):
        return x.view_as(x)

    def backward(self, grad_output):
        return (grad_output * -self.lambd)

def grad_reverse(x, lambd):
    return GradReverse(lambd)(x)

If you want to change the lambda value dynamically during training, you can add a set_lambda method to the network:

def set_lambda(self, lambd):
    self.lambd = lambd

so you can change the lambda value by calling:

model.set_lambda(lambd)

Now, you can use the grad_reverse function as a normal layer in the network:

def forward(self, x):
    x = grad_reverse(x, self.lambd)
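
Putting those pieces together, a minimal sketch of the whole classifier (layer sizes taken from the version I posted earlier; the lambd default in the constructor is my own addition) could look like this:

class domain_classifier(nn.Module):
    def __init__(self, lambd=1.0):
        super(domain_classifier, self).__init__()
        self.lambd = lambd
        self.fc1 = nn.Linear(1200, 100)
        self.fc2 = nn.Linear(100, 1)
        self.drop = nn.Dropout2d(0.25)

    def set_lambda(self, lambd):
        # adjust the gradient reversal strength during training
        self.lambd = lambd

    def forward(self, x):
        # gradients flowing back through here are multiplied by -lambd
        x = grad_reverse(x, self.lambd)
        x = F.leaky_relu(self.drop(self.fc1(x)))
        x = self.fc2(x)
        return F.sigmoid(x)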

I hope this works for you.

Daniel


(Harikrishna.Vydana) #12

Yes… it worked!