Aloha Jonathan!
Your problem seems to be here. This line, grad_input = grad_output.clone(), looks like some leftover code from some other autograd function that isn't relevant here. The next line, grad_input = calcBackward(input), doesn't use grad_output, which is needed to implement the chain rule, if you will.

Because TanhControl is a scalar function that just gets applied element-wise to a tensor, you just need to multiply grad_output by the derivative of tanh(), element-wise:

grad_input = calcBackward(input) * grad_output
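Putting that together, your function might end up looking roughly like the sketch below. This is just a guess at your code, not a drop-in replacement: I'm assuming that your calcBackward (input) computes the element-wise derivative of tanh(), namely 1 - tanh (input)**2, and that your forward() saves input with ctx.save_for_backward().

import torch

def calcBackward (input):
    # assumption: your helper returns the element-wise derivative of tanh()
    return 1.0 - torch.tanh (input)**2.0

class TanhControl (torch.autograd.Function):
    @staticmethod
    def forward (ctx, input):
        ctx.save_for_backward (input)
        return torch.tanh (input)

    @staticmethod
    def backward (ctx, grad_output):
        input, = ctx.saved_tensors
        # chain rule: local derivative of tanh(), element-wise,
        # times the upstream gradient
        grad_input = calcBackward (input) * grad_output
        return grad_input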
Here is a script that compares pytorch's tanh() with a tweaked version of your TanhControl and a version that uses ctx.save_for_backward() to gain (modest) efficiency by saving tanh (input) (rather than just input) so that it doesn't have to recompute it during backward():
import torch
print (torch.__version__)

_ = torch.manual_seed (2021)

# tweaked version of TanhControl: saves input and recomputes
# tanh (input) in backward()
class TanhControl (torch.autograd.Function):
    @staticmethod
    def forward (ctx, input):
        ctx.save_for_backward (input)
        return torch.tanh (input)

    @staticmethod
    def backward (ctx, grad_output):
        input, = ctx.saved_tensors
        # chain rule: derivative of tanh(), element-wise, times grad_output
        return (1.0 - torch.tanh (input)**2.0) * grad_output

# version that saves tanh (input) so backward() doesn't recompute it
class TanhControlB (torch.autograd.Function):
    @staticmethod
    def forward (ctx, input):
        th = torch.tanh (input)
        ctx.save_for_backward (th)
        return th

    @staticmethod
    def backward (ctx, grad_output):
        th, = ctx.saved_tensors
        return (1.0 - th**2.0) * grad_output

t1 = torch.randn (2, 3)
print ('t1 = ...')
print (t1)

# make identical leaf copies before setting requires_grad
t2 = t1.clone()
t3 = t1.clone()
t1.requires_grad = True
t2.requires_grad = True
t3.requires_grad = True

l1 = (torch.tanh (t1)**2).sum()
l2 = (TanhControl.apply (t2)**2).sum()
l3 = (TanhControlB.apply (t3)**2).sum()

print ('l1 =', l1)
print ('l2 =', l2)
print ('l3 =', l3)

l1.backward()
l2.backward()
l3.backward()

print ('t1.grad = ...')
print (t1.grad)
print ('t2.grad - t1.grad = ...')
print (t2.grad - t1.grad)
print ('t3.grad - t1.grad = ...')
print (t3.grad - t1.grad)
Here is its output:
1.7.1
t1 = ...
tensor([[ 2.2871,  0.6413, -0.8615],
        [-0.3649, -0.6931,  0.9023]])
l1 = tensor(2.7625, grad_fn=<SumBackward0>)
l2 = tensor(2.7625, grad_fn=<SumBackward0>)
l3 = tensor(2.7625, grad_fn=<SumBackward0>)
t1.grad = ...
tensor([[ 0.0792,  0.7693, -0.7168],
        [-0.6137, -0.7680,  0.6963]])
t2.grad - t1.grad = ...
tensor([[0., 0., 0.],
        [0., 0., 0.]])
t3.grad - t1.grad = ...
tensor([[0., 0., 0.],
        [0., 0., 0.]])
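As an aside (this isn't part of the script above), you can also sanity-check a custom autograd Function numerically with torch.autograd.gradcheck(), which compares the gradient computed by backward() against a finite-difference estimate. It wants double-precision inputs with requires_grad = True, so something like:

import torch

# assumes TanhControlB from the script above has been defined
tg = torch.randn (2, 3, dtype = torch.double, requires_grad = True)
# prints True if the analytical and numerical gradients agree
print (torch.autograd.gradcheck (TanhControlB.apply, (tg,)))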
Best.
K. Frank