Hi,
I am currently experimenting with automatic differentiation.
I wrote the following script, but the error signal (gradient) returned from `torch.tanh` is not what I expected.
I expect it to be 0.9151: every element of the tensor passed to `torch.tanh` in the script below is 0.3, so its derivative should be $1 - \tanh^2(0.3) \approx 0.9151$ for each element.
Please let me know if any part of my understanding is incorrect.
import torch
class LinearImpl(torch.autograd.Function):
    """Bias-free linear map ``y = x @ w`` with a hand-written backward pass.

    ``backward`` prints the gradient it receives so the signal reaching this
    op can be inspected while experimenting with autograd.
    """

    @staticmethod
    def forward(ctx, x, w):
        """Compute ``x @ w`` and save both operands for the backward pass.

        Args:
            x: Input whose last dimension is ``in_dim``; it may have any
                number of leading (batch) dimensions, including none.
            w: Weight matrix of shape ``(in_dim, out_dim)``.

        Returns:
            ``x @ w``, whose last dimension is ``out_dim``.
        """
        ctx.save_for_backward(x, w)
        return x @ w

    @staticmethod
    def backward(ctx, e):
        """Propagate the upstream gradient ``e = dL/dy`` to ``x`` and ``w``.

        ``e`` is not the local derivative of this op or of the op feeding
        it. It is the full gradient of the loss with respect to this op's
        output ``y``. By the chain rule it has already been multiplied by
        the derivatives of every operation applied after this one, e.g. an
        activation's derivative times the gradient flowing back from a
        later layer.

        Returns:
            ``(grad_x, grad_w)``; an entry is ``None`` when the matching
            input does not require a gradient.
        """
        x, w = ctx.saved_tensors
        grad_x = grad_w = None
        print(e)
        if ctx.needs_input_grad[1]:
            # Collapse all leading dims so the weight gradient sums over
            # every sample: (in, N) @ (N, out) -> (in, out). For 2-D x this
            # equals x.t() @ e. Unlike x.t(), it is also correct for 1-D x
            # (where x.t() @ e would be a scalar dot product) and for
            # batched x (where x.t() raises). reshape (not view) copes with
            # the expanded, non-contiguous e that e.g. .sum() produces.
            x_2d = x.reshape(-1, x.shape[-1])
            e_2d = e.reshape(-1, e.shape[-1])
            grad_w = x_2d.t() @ e_2d
        if ctx.needs_input_grad[0]:
            grad_x = e @ w.t()
        return grad_x, grad_w
class Linear(torch.nn.Module):
    """Bias-free linear layer backed by the custom ``LinearImpl`` function.

    Every weight starts at the same constant, which keeps forward and
    backward values easy to check by hand.
    """

    def __init__(self, in_dim, out_dim, init_value=0.3):
        """Create an ``(in_dim, out_dim)`` weight filled with ``init_value``.

        Args:
            in_dim: Size of the last dimension of the input.
            out_dim: Size of the last dimension of the output.
            init_value: Constant every weight entry is initialised to
                (default 0.3).
        """
        super().__init__()
        # float() keeps the weight floating-point even for an int fill;
        # torch.full with an int would give int64, which cannot require grad.
        weight = torch.full((in_dim, out_dim), float(init_value))
        self.weight = torch.nn.Parameter(weight)

    def forward(self, x):
        """Return ``x @ self.weight``, differentiated through ``LinearImpl``."""
        return LinearImpl.apply(x, self.weight)
# Build x directly as a leaf tensor so x.grad is populated by backward().
# `torch.zeros(..., requires_grad=True) + 0.5` produces a non-leaf result
# whose .grad is never filled in.
x = torch.full((2, 2), 0.5, requires_grad=True)
fc1 = Linear(2, 3)
fc2 = Linear(3, 2)

# Each LinearImpl.backward prints its upstream gradient e = dL/dy.
# fc2 runs first in the backward pass, fc1 second:
#   fc2: dL/dy2 = ones(2, 2)                      (from y.sum())
#   fc1: dL/dy1 = (dL/dy2 @ W2.T) * (1 - tanh(h)**2)
#               = 0.6 * 0.9151 ≈ 0.5491 per element
# Here h = x @ W1 = 0.5*0.3 + 0.5*0.3 = 0.3, and every W2 entry is 0.3, so
# each element of ones(2, 2) @ W2.T is 0.6. tanh's local derivative 0.9151
# is multiplied by the gradient coming back from fc2 (chain rule), so it is
# never printed on its own.
y = fc2(torch.tanh(fc1(x)))
y.sum().backward()