PyTorch autograd not matching computed derivatives?

I am not able to figure out why my by-hand derivatives do not match PyTorch's.

Equations:

h0 = 0.5
w_h = 2.0
w_y = 40.0
h1_raw = w_h * h0
h1 = tanh(h1_raw)
h2_raw = w_h * h1
h2 = tanh(h2_raw)
y2 = w_y * h2

Find dy2/dw_h and dy2/dw_y.
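For reference, here are the forward values these equations produce (a quick sketch in plain Python, independent of autograd):

```python
import math

# forward pass from the equations above
h0 = 0.5
w_h = 2.0
w_y = 40.0
h1_raw = w_h * h0          # 1.0
h1 = math.tanh(h1_raw)     # ~0.7616
h2_raw = w_h * h1          # ~1.5232
h2 = math.tanh(h2_raw)     # ~0.9093
y2 = w_y * h2              # ~36.37
print(h1, h2, y2)
```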

W_h.grad : PyTorch autograd via Y2.backward()
dy2_dwh : computed using the analytic method
dw_h : computed using a backward pass over the computation graph (Stanford CS231n course)

Here is the code and output.

import torch

W_h = torch.tensor(2.0, requires_grad=True)
w_h = W_h.detach().item()

W_y = torch.tensor(40.0, requires_grad=True)
w_y = W_y.detach().item()

H0 = torch.tensor(0.5, requires_grad=False)
h0 = H0.item()

H1_raw = W_h * H0
h1_raw = H1_raw.item()
H1 = torch.tanh(H1_raw)
h1 = H1.detach().item()

H2_raw = W_h * H1
h2_raw = H2_raw.detach().item()
H2 = torch.tanh(H2_raw)
Y2 = W_y * H2
h2 = H2.detach().item()

Y2.backward()

# analytic method for derivatives
#   dy2/dw_y = h2 * dw_y/dw_y = h2 (product rule)
#   dy2/dw_h = h2 * dw_y/dw_h + w_y * dh2/dw_h = w_y * dh2/dw_h (product rule; w_y is constant w.r.t. w_h)
#   dh2/dw_h = (1 - h2_raw**2) * dh2_raw/dw_h (chain rule)
#   dh2_raw/dw_h = h1 * dw_h/dw_h + w_h * dh1/dw_h = h1 + w_h * dh1/dw_h (product rule)
#   dh1/dw_h = (1 - h1_raw**2) * dh1_raw/dw_h (chain rule)
#   dh1_raw/dw_h = h0 * dw_h/dw_h + w_h * dh0/dw_h = h0 (product rule; h0 is constant)
dh0_dwh = 0 #because H0 is constant for w_h
dh1raw_dwh = h0 + w_h * dh0_dwh
dh1_dwh = (1 - h1_raw**2) * dh1raw_dwh
dh2raw_dwh = h1 + w_h * dh1_dwh
dh2_dwh = (1 - h2_raw**2) * dh2raw_dwh
dwy_dwh = 0
dy2_dwh = h2 * dwy_dwh + w_y * dh2_dwh

dh2_dwy = 0
dy2_dwy = h2 + w_y * dh2_dwy

# backward pass over computation graph
dy2 = 1
dw_y = h2 * 1
dh2 = w_y * 1
dh2_raw = (1 - h2_raw**2) * dh2
dh1 = w_h * dh2_raw
dh1_raw = (1 - h1_raw**2) * dh1
dw_h = h1 * dh2_raw + h0 * dh1_raw

print(W_h.grad, "<>", dy2_dwh, "<>", dw_h)
print(W_y.grad, "<>", dy2_dwy, "<>", dw_y)

tensor(8.1888) <> -40.21530288726289 <> -40.21530288726289
tensor(0.9093) <> 0.9092516899108887 <> 0.9092516899108887

Note: at Desmos | Graphing Calculator, d/dx(40 * tanh(x * tanh(x * 0.5))) at x = 2 is about 8, which agrees with autograd but not with my computed values.
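A finite-difference check (not from the thread, just a numeric sketch of the same forward pass) also lands near 8, agreeing with autograd rather than the hand computation:

```python
import math

def y2_of(w_h, w_y=40.0, h0=0.5):
    # same forward pass as in the question
    h1 = math.tanh(w_h * h0)
    h2 = math.tanh(w_h * h1)
    return w_y * h2

# central finite difference of y2 with respect to w_h at w_h = 2.0
eps = 1e-6
grad_fd = (y2_of(2.0 + eps) - y2_of(2.0 - eps)) / (2 * eps)
print(grad_fd)  # ~8.19
```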

Hi Nitesh!

I haven’t looked at the details of your code, but it appears that your analytic
formula for the derivative of tanh() is incorrect.

Best.

K. Frank

I am applying the chain rule,

  (f(g(w_h)))' = f'(g) . g'(w_h), where f = tanh() (so f(g) = h1) and g = h1_raw
   f' = 1 - f*f = 1 - h1*h1
   g' = dh1_raw/dw_h
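The subtlety is which value feeds the tanh derivative: f' = 1 - tanh(x)**2 must be evaluated with the tanh output (h1), not the pre-activation (h1_raw). A quick numeric sketch of the difference:

```python
import math

x = 1.0  # h1_raw in the example above
correct = 1 - math.tanh(x) ** 2   # 1 - h1**2, uses the tanh output
buggy = 1 - x ** 2                # 1 - h1_raw**2, uses the pre-activation (the mistake)

# finite difference of tanh at x confirms the correct form
eps = 1e-6
fd = (math.tanh(x + eps) - math.tanh(x - eps)) / (2 * eps)
print(correct, buggy, fd)  # correct ~0.42, buggy is 0.0
```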

I realized what I did wrong. Thank you @KFrank

buggy code

dh1_dwh = (1 - h1_raw**2) * dh1raw_dwh
dh2raw_dwh = h1 + w_h * dh1_dwh
dh2_dwh = (1 - h2_raw**2) * dh2raw_dwh
dh2_raw = (1 - h2_raw**2) * dh2
dh1 = w_h * dh2_raw
dh1_raw = (1 - h1_raw**2) * dh1

correction

dh1_dwh = (1 - h1**2) * dh1raw_dwh
dh2raw_dwh = h1 + w_h * dh1_dwh
dh2_dwh = (1 - h2**2) * dh2raw_dwh
dh2_raw = (1 - h2**2) * dh2
dh1 = w_h * dh2_raw
dh1_raw = (1 - h1**2) * dh1
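Putting the correction together, a minimal end-to-end sketch (same scalar values as above) where the analytic gradients now match autograd:

```python
import torch

W_h = torch.tensor(2.0, requires_grad=True)
W_y = torch.tensor(40.0, requires_grad=True)
H0 = torch.tensor(0.5)

# forward pass
H1 = torch.tanh(W_h * H0)
H2 = torch.tanh(W_h * H1)
Y2 = W_y * H2
Y2.backward()

# analytic gradients with the corrected tanh derivative:
# d tanh(u)/du = 1 - tanh(u)**2, i.e. evaluated at the tanh *output*
h0, w_h, w_y = H0.item(), W_h.item(), W_y.item()
h1, h2 = H1.item(), H2.item()

dh1_dwh = (1 - h1 ** 2) * h0
dh2_dwh = (1 - h2 ** 2) * (h1 + w_h * dh1_dwh)
dy2_dwh = w_y * dh2_dwh  # ~8.1888, matching W_h.grad
dy2_dwy = h2             # ~0.9093, matching W_y.grad

print(W_h.grad.item(), "<>", dy2_dwh)
print(W_y.grad.item(), "<>", dy2_dwy)
```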