How to check the gradients?

Hi all.
I want to check the gradient is reversed well or not.
Then how to check the gradients ?

class GradReverse(Function):
    def __init__(self, lambd):
        self.lambd = lambd

    def forward(self, x):
        return x.view_as(x)

    def backward(self, grad_output):
        return grad_output * -self.lambd


def grad_reverse(x, lambd=1.0):
    return GradReverse(lambd)(x)

class Discriminator(nn.Module):
    def __init__(self, channel, pool, classes=10):
        super(Discriminator, self).__init__()
        self.cls_fn = nn.Linear(channel, classes)

    def forward(self, x, reverse=False, eta=0.1):
        if reverse:
            x = grad_reverse(x, eta)
        x = self.cls_fn(x)
        return x
def train():
    discriminator = Discriminator()

    out = discriminator(input, reverse=True)
    loss = CE_criterion(out, label)

    out2 = discriminator(input, reverse=False)
    loss2 = CE_criterion(out2, label)

    loss.backward(retain_graph=True)
    loss2.backward()

Hi,

I guess you want to compare the gradient values with and without your reverse flag and make sure they are reversed?

Also you want to follow the instructions from https://pytorch.org/docs/stable/notes/extending.html on how to write the custom Function. The one you use is the old version (before 0.3) and is being removed in latest version.

1 Like

First of all, thanks for the answer.
According to your instructions, I’ve changed the code.
Could you check it please?

class GradReverse(Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x

    @staticmethod
    def backward(ctx, grad_output):
        return -grad_output

class Discriminator(nn.Module):
    def __init__(self, channel, pool, classes=10):
        super(Discriminator, self).__init__()
        self.cls_fn = nn.Linear(channel, classes)

    def forward(self, x, reverse=False):
        x = self.cls_fn(x)
        if reverse:
            x = GradReverse.apply(x)
        return x

Hi,

This looks good. The ctx.save_for_backward(x) in the forward is not really needed since you don’t use it in the bakward so it could be just removed. Otherwise all good.

1 Like
class GradReverse(Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x

    @staticmethod
    def backward(ctx, grad_output):
        return -grad_output

class Discriminator(nn.Module):
    def __init__(self, channel, pool, classes=10):
        super(Discriminator, self).__init__()
        self.cls_fn = nn.Linear(channel, classes)

    def forward(self, x, reverse=False):
        x = self.cls_fn(x)
        if reverse:
            x = GradReverse.apply(x)
        return x


def train():
    discriminator = Discriminator(3, 10)
    out = discriminator(input, reverse=True)
    ce_loss = CE_criterion(out, label)
    ...
    total_loss += ce_loss

    grads = {}
    def save_grad(name):
       def hook(grad):
          grads[name] = grad
        return hook

    ce_loss.register_hook(save_grad('ce_loss'))
    total_loss.backward()
    print(grads['ce_loss']) # 0.125
   

@albanD
Even if I call Discriminator() with reverse=True, the gradient of ce_loss is still positive value.
What’s wrong with my code ?

If you consider your network forward goes from top to bottom, the ce_loss is below the function that reverses the gradients. So it is expected not to be influenced by it.
You will see it’s effect on the gradients of the elements above it (like the output of cls_fn or input).