I can't figure out why my hand-computed derivatives don't match PyTorch's.
Equations:
h0 = 0.5
w_h = 2.0
w_y = 40.0
h1_raw = w_h * h0
h1 = tanh(h1_raw)
h2_raw = w_h * h1
h2 = tanh(h2_raw)
y2 = w_y * h2
Find dy2/dw_h and dy2/dw_y.
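For reference, a sketch of how the chain rule expands for these equations, assuming the standard identity d/dx tanh(x) = 1 - tanh(x)^2 (so each tanh factor is written in terms of the activation h, not the raw input):

```latex
\begin{aligned}
\frac{\partial y_2}{\partial w_y} &= h_2 \\
\frac{\partial h_1}{\partial w_h} &= \left(1 - h_1^2\right) h_0 \\
\frac{\partial y_2}{\partial w_h} &= w_y \left(1 - h_2^2\right)
    \left(h_1 + w_h \,\frac{\partial h_1}{\partial w_h}\right)
\end{aligned}
```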
W_h.grad: PyTorch autograd, from Y2.backward()
dy2_dwh: computed using the analytic method
dw_h: computed using a backward pass over the computation graph (Stanford CS231n course)
Here is the code and output.
import torch
W_h = torch.tensor(2.0, requires_grad=True)
w_h = W_h.detach().item()
W_y = torch.tensor(40.0, requires_grad=True)
w_y = W_y.detach().item()
H0 = torch.tensor(0.5, requires_grad=False)
h0 = H0.item()
H1_raw = W_h * H0
h1_raw = H1_raw.item()
H1 = torch.tanh(H1_raw)
h1 = H1.detach().item()
H2_raw = W_h * H1
h2_raw = H2_raw.detach().item()
H2 = torch.tanh(H2_raw)
Y2 = W_y * H2
h2 = H2.detach().item()
Y2.backward()
# analytic method for derivatives
# dy/dw_y = h2 * dw_y/dw_y = h2 (product rule)
# dy/dw_h = h2 * dw_y/dw_h + w_y * dh2/dw_h = w_y * dh2/dw_h (product rule & w_y is constant for w_h)
# dh2/dw_h = (1 - h2_raw**2) * dh2_raw/dw_h (chain rule)
# dh2_raw/dw_h = h1 * dw_h/dw_h + w_h * dh1/dw_h = h1 + w_h * dh1/dw_h (product rule)
# dh1/dw_h = (1 - h1_raw**2) * dh1_raw/dw_h (chain rule)
# dh1_raw/dw_h = h0 * dw_h/dw_h + w_h * dh0/dw_h = h0 (product rule & h0 is constant)
dh0_dwh = 0 #because H0 is constant for w_h
dh1raw_dwh = h0 + w_h * dh0_dwh
dh1_dwh = (1 - h1_raw**2) * dh1raw_dwh
dh2raw_dwh = h1 + w_h * dh1_dwh
dh2_dwh = (1 - h2_raw**2) * dh2raw_dwh
dwy_dwh = 0
dy2_dwh = h2 * dwy_dwh + w_y * dh2_dwh
dh2_dwy = 0
dy2_dwy = h2 + w_y * dh2_dwy
# backward pass over computation graph
dy2 = 1
dw_y = h2 * 1
dh2 = w_y * 1
dh2_raw = (1 - h2_raw**2) * dh2
dh1 = w_h * dh2_raw
dh1_raw = (1 - h1_raw**2) * dh1
dw_h = h1 * dh2_raw + h0 * dh1_raw
print(W_h.grad, "<>", dy2_dwh, "<>", dw_h)
print(W_y.grad, "<>", dy2_dwy, "<>", dw_y)
tensor(8.1888) <> -40.21530288726289 <> -40.21530288726289
tensor(0.9093) <> 0.9092516899108887 <> 0.9092516899108887
Note: in the Desmos graphing calculator, d/dx(40 * tanh(x * tanh(x * 0.5))) at x = 2 is roughly 8, which agrees with the autograd value rather than my hand-computed one.
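One thing worth checking is the tanh derivative factor. The code above uses `1 - h_raw**2`, whereas the derivative of tanh(x) is `1 - tanh(x)**2`, i.e. `1 - h**2` in terms of the activation. Below is a minimal sketch in plain Python (no PyTorch) that recomputes the hand derivative with that factor and compares it against a central finite difference. The helper `forward` and the step size `eps` are illustrative choices, not part of the original code.

```python
import math

def forward(w_h, w_y=40.0, h0=0.5):
    """Hypothetical helper: recompute h1, h2, y2 from the equations above."""
    h1 = math.tanh(w_h * h0)
    h2 = math.tanh(w_h * h1)
    return h1, h2, w_y * h2

w_h, w_y, h0 = 2.0, 40.0, 0.5
h1, h2, y2 = forward(w_h, w_y, h0)

# Hand derivative, using d/dx tanh(x) = 1 - tanh(x)**2,
# so each factor uses the activation (h1, h2), not the raw input.
dh1_dwh = (1 - h1**2) * h0
dh2raw_dwh = h1 + w_h * dh1_dwh
dy2_dwh = w_y * (1 - h2**2) * dh2raw_dwh

# Independent numerical check: central finite difference on y2(w_h).
eps = 1e-6  # assumed step size, small enough for this smooth function
numeric = (forward(w_h + eps, w_y, h0)[2] - forward(w_h - eps, w_y, h0)[2]) / (2 * eps)

print(dy2_dwh, numeric)
```

If the two printed values agree with each other and with `W_h.grad`, the same substitution (`h1**2` and `h2**2` in place of `h1_raw**2` and `h2_raw**2`) should apply to both the analytic method and the backward pass in the original code.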