oasjd7 (oasjd7) June 23, 2020, 11:51am #1
Hi all.
I want to check whether the gradient is actually being reversed or not.
How can I check the gradients?
```python
import torch.nn as nn
from torch.autograd import Function

class GradReverse(Function):
    def __init__(self, lambd):
        self.lambd = lambd

    def forward(self, x):
        return x.view_as(x)

    def backward(self, grad_output):
        return grad_output * -self.lambd

def grad_reverse(x, lambd=1.0):
    return GradReverse(lambd)(x)

class Discriminator(nn.Module):
    def __init__(self, channel, pool, classes=10):
        super(Discriminator, self).__init__()
        self.cls_fn = nn.Linear(channel, classes)

    def forward(self, x, reverse=False, eta=0.1):
        if reverse:
            x = grad_reverse(x, eta)
        x = self.cls_fn(x)
        return x

def train():
    discriminator = Discriminator(channel, pool)
    out = discriminator(input, reverse=True)
    loss = CE_criterion(out, label)
    out2 = discriminator(input, reverse=False)
    loss2 = CE_criterion(out2, label)
    loss.backward(retain_graph=True)
    loss2.backward()
```
albanD (Alban D) June 23, 2020, 3:08pm #2
Hi,
I guess you want to compare the gradient values with and without your reverse flag and make sure they are reversed?
Also, you should follow the instructions from https://pytorch.org/docs/stable/notes/extending.html on how to write a custom Function. The one you use is the old style (from before 0.3) and is being removed in the latest versions.
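If it helps, one way to do that comparison is to run the same computation twice, once through a reversal Function and once without, and check that the input gradients are exact negatives of each other. A minimal sketch, using the new-style Function API from the link above:

```python
import torch
from torch.autograd import Function

class GradReverse(Function):
    """Identity in the forward pass; flips the gradient sign in backward."""
    @staticmethod
    def forward(ctx, x):
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        return -grad_output

x = torch.randn(4, 3, requires_grad=True)

# Plain branch: gradient of sum(2 * x) w.r.t. x is 2 everywhere.
g_plain = torch.autograd.grad((x * 2.0).sum(), x)[0]

# Reversed branch: same forward value, gradient should be -2 everywhere.
g_rev = torch.autograd.grad((GradReverse.apply(x) * 2.0).sum(), x)[0]

print(torch.allclose(g_rev, -g_plain))  # True
```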
oasjd7 (oasjd7) June 23, 2020, 4:09pm #3
First of all, thanks for the answer.
I've changed the code according to your instructions.
Could you check it, please?
```python
import torch.nn as nn
from torch.autograd import Function

class GradReverse(Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x

    @staticmethod
    def backward(ctx, grad_output):
        return -grad_output

class Discriminator(nn.Module):
    def __init__(self, channel, pool, classes=10):
        super(Discriminator, self).__init__()
        self.cls_fn = nn.Linear(channel, classes)

    def forward(self, x, reverse=False):
        x = self.cls_fn(x)
        if reverse:
            x = GradReverse.apply(x)
        return x
```
albanD (Alban D) June 23, 2020, 5:22pm #4
Hi,
This looks good. The ctx.save_for_backward(x) in the forward is not really needed since you don't use it in the backward, so it can just be removed. Otherwise all good.
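Applying that suggestion, the Function could be reduced to the following sketch (ctx is unused here but still required by the staticmethod signatures):

```python
import torch
from torch.autograd import Function

class GradReverse(Function):
    @staticmethod
    def forward(ctx, x):
        # Nothing saved on ctx: the backward only needs grad_output.
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        return -grad_output

# Quick check: forward is the identity, backward flips the sign.
x = torch.ones(3, requires_grad=True)
y = GradReverse.apply(x)
y.sum().backward()
print(x.grad)  # tensor([-1., -1., -1.])
```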
oasjd7 (oasjd7) June 24, 2020, 5:17am #5
```python
import torch.nn as nn
from torch.autograd import Function

class GradReverse(Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x

    @staticmethod
    def backward(ctx, grad_output):
        return -grad_output

class Discriminator(nn.Module):
    def __init__(self, channel, pool, classes=10):
        super(Discriminator, self).__init__()
        self.cls_fn = nn.Linear(channel, classes)

    def forward(self, x, reverse=False):
        x = self.cls_fn(x)
        if reverse:
            x = GradReverse.apply(x)
        return x

def train():
    discriminator = Discriminator(3, 10)
    out = discriminator(input, reverse=True)
    ce_loss = CE_criterion(out, label)
    ...
    total_loss += ce_loss

    grads = {}
    def save_grad(name):
        def hook(grad):
            grads[name] = grad
        return hook

    ce_loss.register_hook(save_grad('ce_loss'))
    total_loss.backward()
    print(grads['ce_loss'])  # 0.125
```
@albanD
Even if I call the Discriminator with reverse=True, the gradient of ce_loss is still a positive value.
What's wrong with my code?
albanD (Alban D) June 24, 2020, 3:31pm #6
If you consider that your network's forward pass goes from top to bottom, the ce_loss is below the function that reverses the gradients, so it is expected not to be influenced by it.
You will see its effect on the gradients of the elements above it (like the output of cls_fn, or input).
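To make that concrete, here is a sketch re-using the thread's setup (a plain nn.Linear plus cross-entropy; the names `lin`, `base`, and `run` are illustrative, not from the original code). The gradient arriving at the loss tensor is 1.0 with or without the reversal, because the hook fires before the backward pass reaches the reversed Function, while the gradient on the input, which sits above the reversal, is sign-flipped:

```python
import torch
import torch.nn as nn
from torch.autograd import Function

class GradReverse(Function):
    @staticmethod
    def forward(ctx, x):
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        return -grad_output

torch.manual_seed(0)
lin = nn.Linear(3, 10)
ce = nn.CrossEntropyLoss()
base = torch.randn(4, 3)
label = torch.randint(0, 10, (4,))

def run(reverse):
    inp = base.clone().requires_grad_(True)
    logits = lin(inp)
    if reverse:
        logits = GradReverse.apply(logits)  # reversal sits between logits and the loss
    loss = ce(logits, label)
    seen = {}
    def save(g):
        seen['loss_grad'] = g  # gradient flowing INTO the loss tensor
    loss.register_hook(save)
    loss.backward()
    return seen['loss_grad'], inp.grad

g_loss_plain, g_in_plain = run(False)
g_loss_rev, g_in_rev = run(True)

# The loss's own incoming gradient is unchanged, which matches the
# "still positive" observation above.
print(float(g_loss_plain), float(g_loss_rev))  # 1.0 1.0
# The input gradient, above the reversal, is sign-flipped.
print(torch.allclose(g_in_rev, -g_in_plain))  # True
```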