The whole code, just for reference:
'''
This script is written to test:
1) what happens if we differentiate through PyTorch's backward pass, and whether
that is consistent with computing a second derivative directly.
'''
import torch
import torch.nn as nn
from grad_test_sympy import symbolic_test
import sys
from pdb import set_trace as st
def test_does_backward_backward_match_second_derivative(experiment_type,grad_computation_type='dl_dw_zero_grad'):
'''
    1) what happens if we differentiate through PyTorch's backward pass, and whether
    that is consistent with computing a second derivative directly.
    :param str experiment_type:
        - grad_constant: makes sure the sympy code and the PyTorch code match.
          When g is constant there is no second call to backward, so this is just
          a sanity test to check that things are working.
        - grad_constant_but_use_backward_on_loss: calls backward twice (once on
          the loss, then on J, which depends on the loss). The expected result
          was for it to match the sympy code. It doesn't match.
        - grad_analytic_only_backward_on_J: computes g analytically from the
          assumed loss (w*x - y)**2 and only calls backward on J.
'''
## variable declaration
w = torch.tensor([2.0], requires_grad=True)
x = torch.tensor([3.0], requires_grad=False)
y = torch.tensor([4.0], requires_grad=False)
x2 = torch.tensor([5.0], requires_grad=False)
y2 = torch.tensor([6.0], requires_grad=False)
if experiment_type == 'grad_constant':
        ## computes backward pass on J (i.e. dJ_dw) assuming g is constant
# compute g
if grad_computation_type == 'g_hard_coded':
grad = 2*(w*x-y)*x # assumption, loss = (w*x - y)**2
g = grad.item()
elif grad_computation_type == 'dl_dw_zero_grad':
# compute the backward pass dl_dw
loss = (w*x-y)**2
loss.backward()
g = w.grad.item()
            # zero out the gradients on the original tensor
w.grad.zero_() # note this operation does NOT make g zero.
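            # (g was extracted with .item(), so it is a plain Python float, already detached from w.grad)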
else:
ww = torch.randn(w.size(), requires_grad=True)
ww.data = w.clone()
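            # ww is a separate leaf tensor that only copies w's value; the loss graph below is rooted at ww, not w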
loss = (ww*x-y)**2
loss.backward()
grad = ww.grad
g = grad.item()
assert( (2*(w*x-y)*x).item() == grad.item() )
# compute w_new
w_new = w - (g+w**2) * g
# compute final loss J
J = (w_new + x2 + y2)**2
# computes derivative of J
J.backward()
#dw_new_dw = w_new.grad.item()
dJ_dw = w.grad.item()
elif experiment_type == 'grad_constant_but_use_backward_on_loss':
        ## computes backward pass on J (i.e. dJ_dw), but g was already produced by a backward pass
# compute g
loss = (w*x-y)**2
loss.backward()
#g = w.grad.item() # dl_dw
g = w.grad # dl_dw
g.requires_grad = True
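        # NOTE: loss.backward() was called without create_graph=True, so w.grad is a leaf
        # tensor with no history linking it back to w; setting requires_grad here only
        # starts a fresh graph rooted at g, and J.backward() below treats g as a constant
        # with respect to w. Also, w.grad is never zeroed, so the dl/dw from the first
        # backward is still accumulated into the dJ_dw read out at the end.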
print(g)
# compute w_new
w_new = w - (g+w**2) * g
# compute final loss J
J = (w_new + x2 + y2)**2
# computes derivative of J
J.backward()
#dw_new_dw = w_new.grad.item()
dJ_dw = w.grad.item()
elif experiment_type == 'grad_analytic_only_backward_on_J':
        ## computes backward pass on J (i.e. dJ_dw); here g is the analytic gradient and stays in the autograd graph
# compute g
grad = 2*(w*x-y)*x # assumption, loss = (w*x - y)**2
g = grad
# compute w_new
w_new = w - (g+w**2) * g
# compute final loss J
J = (w_new + x2 + y2)**2
# computes derivative of J
J.backward()
#dw_new_dw = w_new.grad.item()
dJ_dw = w.grad.item()
else:
raise ValueError(f'This: experiment_type={experiment_type}, is not a valid test')
##
print('---- test results ----')
print(f'experiment_type = {experiment_type}')
print('-- Pytorch results')
print(f'g = {g}')
#print(f'dw_new_dw = {dw_new_dw}')
print(f'dJ_dw = {dJ_dw}')
print('-- Sympy results')
g, dw_new_dw, dJ_dw = symbolic_test(experiment_type)
print(f'g_SYMPY = {g}')
#print(f'dw_new_dw_SYMPY = {dw_new_dw}')
print(f'dJ_dw_SYMPY = {dJ_dw}')
if __name__ == '__main__':
test_does_backward_backward_match_second_derivative(experiment_type='grad_constant')
#test_does_backward_backward_match_second_derivative(experiment_type='grad_constant_but_use_backward_on_loss')
#test_does_backward_backward_match_second_derivative(experiment_type='grad_analytic_only_backward_on_J')
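For comparison, here is a minimal sketch (not part of the original script) of how g can be kept inside the autograd graph so that J.backward() also differentiates through it, using torch.autograd.grad with create_graph=True. It reuses the same toy values as above; the expectation is that the resulting dJ_dw should agree with the sympy second-derivative result.

import torch

w = torch.tensor([2.0], requires_grad=True)
x = torch.tensor([3.0])
y = torch.tensor([4.0])
x2 = torch.tensor([5.0])
y2 = torch.tensor([6.0])

loss = (w*x - y)**2
# create_graph=True records the graph of the gradient itself,
# so g remains a differentiable function of w
g, = torch.autograd.grad(loss, w, create_graph=True)

w_new = w - (g + w**2)*g
J = (w_new + x2 + y2)**2
J.backward()

print(f'g = {g.item()}')
print(f'dJ_dw = {w.grad.item()}')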