I am trying to extract gradients and later use them as data (i.e. have them be non-trainable). For that I have to call .backward() twice. Doing that seems to mess up my code, as noted in my other question (What does doing .backward twice do?). However, for the code to work as I intend (without the issue of calling backward twice), I have to clone the model parameters:
ww = torch.randn(w.size(), requires_grad=True)
ww.data = w.clone()
loss = (ww*x-y)**2
loss.backward()
grad = ww.grad
assert( (2*(w*x-y)*x).item() == grad.item() )
In this silly toy example it's not an issue, but I am worried that it might become one later. Does anyone know how to get the gradients without having to clone the data and create extra vectors?
All the code:
if experiment_type == 'grad_constant':
    ## computes backward pass on J (i.e. dJ_dw) assuming g is constant
    # compute g
    if g_hard_coded:
        grad = 2*(w*x-y)*x # assumption, loss = (w*x - y)**2
    else:
        ww = torch.randn(w.size(), requires_grad=True)
        ww.data = w.clone()
        loss = (ww*x-y)**2
        loss.backward()
        grad = ww.grad
        assert( (2*(w*x-y)*x).item() == grad.item() )
    g = grad.item()
    # compute w_new
    w_new = w - (g+w**2) * g
    # compute final loss J
    J = (w_new + x2 + y2)**2
    # computes derivative of J
    J.backward()
    #dw_new_dw = w_new.grad.item()
    dJ_dw = w.grad.item()
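A minimal sketch of the double-backward effect, assuming the same toy values as the full script further down (w=2, x=3, y=4, x2=5, y2=6): .grad accumulates across .backward() calls, so if the dl/dw from the first call is not cleared, w.grad after J.backward() holds dl/dw + dJ/dw rather than dJ/dw alone. The names accumulated and clean are illustrative only.

```python
import torch

# Toy values assumed to match the full script further down.
w = torch.tensor([2.0], requires_grad=True)
x, y = torch.tensor([3.0]), torch.tensor([4.0])
x2, y2 = torch.tensor([5.0]), torch.tensor([6.0])

# First backward: dl/dw = 2*(w*x - y)*x = 12 is written into w.grad.
loss = (w * x - y) ** 2
loss.backward()
g = w.grad.item()  # plain Python float, so autograd sees it as a constant

# Second backward WITHOUT clearing w.grad: the new gradient is added to it.
w_new = w - (g + w ** 2) * g
J = (w_new + x2 + y2) ** 2
J.backward()
accumulated = w.grad.item()  # dl/dw + dJ/dw

# Same computation after clearing the stale gradient first.
w.grad = None
w_new = w - (g + w ** 2) * g
J = (w_new + x2 + y2) ** 2
J.backward()
clean = w.grad.item()  # dJ/dw only

print(accumulated, clean)  # 16838.0 16826.0
```

This is the same reason the script's dl_dw_zero_grad branch calls w.grad.zero_() before building J.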
I'm not sure what the question is here.
As a note, you should never have to use .data. It should be replaced either by .detach() or by a with torch.no_grad(): block (there are other posts on this forum that discuss this).
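As a sketch of what that replacement could look like for the cloning snippet in the question (assuming the same w, x, y values as the script): w.detach().clone() gives a new tensor with w's value and no autograd history, which can then be turned into a fresh leaf, while torch.no_grad() computes a value without recording it in the graph. The name grad_analytic is illustrative.

```python
import torch

w = torch.tensor([2.0], requires_grad=True)
x, y = torch.tensor([3.0]), torch.tensor([4.0])

# Instead of: ww = torch.randn(...); ww.data = w.clone()
# copy w's value without its history, then make the copy a new leaf.
ww = w.detach().clone().requires_grad_(True)
loss = (ww * x - y) ** 2
loss.backward()
grad = ww.grad  # dl/dww = 2*(w*x - y)*x

# Alternatively, nothing inside torch.no_grad() is recorded, so the
# analytic gradient below is a constant tensor with no graph attached.
with torch.no_grad():
    grad_analytic = 2 * (w * x - y) * x

print(grad.item(), grad_analytic.requires_grad, w.grad)  # 12.0 False None
```

Because ww is a separate leaf, w.grad is never touched, which avoids the accumulation issue without assigning to .data.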
Hi AlbanD. What I am trying to do is make sure that the gradient information w.grad is treated as a number in the computation graph and not as a variable that one can differentiate, i.e. treated as a constant. Does that make sense?
Would .data still make things go wrong? In what way? I tried some code and printed values, and the differentiation worked as I expected, so I am not sure what you are suggesting or how it fixes my scenario.
'''
This script is written to test:
1) what happens if we differentiate through pytorch's backward pass, and whether
it's consistent with taking a second derivative
'''
import torch
import torch.nn as nn
from grad_test_sympy import symbolic_test
import sys
from pdb import set_trace as st
def test_does_backward_backward_match_second_derivative(experiment_type,grad_computation_type='dl_dw_zero_grad'):
    '''
    1) what happens if we differentiate through pytorch's backward pass, and whether
    it's consistent with taking a second derivative
    :param str experiment_type:
        - grad_constant: this one is to make sure the sympy code and pytorch match.
            When g is constant there is no second call of backward, so it's just
            a sanity test to make sure things are working.
        - grad_constant_but_use_backward_on_loss: this one calls
            backward twice (once on the loss, second on J, which depends on the loss)
            and I expect the result to match the sympy code.
            It doesn't match.
        -
    '''
    ## variable declaration
    w = torch.tensor([2.0], requires_grad=True)
    x = torch.tensor([3.0], requires_grad=False)
    y = torch.tensor([4.0], requires_grad=False)
    x2 = torch.tensor([5.0], requires_grad=False)
    y2 = torch.tensor([6.0], requires_grad=False)
    if experiment_type == 'grad_constant':
        ## computes backward pass on J (i.e. dJ_dw) assuming g is constant
        # compute g
        if grad_computation_type == 'g_hard_coded':
            grad = 2*(w*x-y)*x # assumption, loss = (w*x - y)**2
            g = grad.item()
        elif grad_computation_type == 'dl_dw_zero_grad':
            # compute the backward pass dl_dw
            loss = (w*x-y)**2
            loss.backward()
            g = w.grad.item()
            # zero out the gradients on the original tensor
            w.grad.zero_() # note this operation does NOT make g zero.
        else:
            ww = torch.randn(w.size(), requires_grad=True)
            ww.data = w.clone()
            loss = (ww*x-y)**2
            loss.backward()
            grad = ww.grad
            g = grad.item()
            assert( (2*(w*x-y)*x).item() == grad.item() )
        # compute w_new
        w_new = w - (g+w**2) * g
        # compute final loss J
        J = (w_new + x2 + y2)**2
        # computes derivative of J
        J.backward()
        #dw_new_dw = w_new.grad.item()
        dJ_dw = w.grad.item()
    elif experiment_type == 'grad_constant_but_use_backward_on_loss':
        ## computes backward pass on J (i.e. dJ_dw) but g was already computed by a backward pass
        # compute g
        loss = (w*x-y)**2
        loss.backward()
        #g = w.grad.item() # dl_dw
        g = w.grad # dl_dw
        g.requires_grad = True
        print(g)
        # compute w_new
        w_new = w - (g+w**2) * g
        # compute final loss J
        J = (w_new + x2 + y2)**2
        # computes derivative of J
        J.backward()
        #dw_new_dw = w_new.grad.item()
        dJ_dw = w.grad.item()
    elif experiment_type == 'grad_analytic_only_backward_on_J':
        ## computes backward pass on J (i.e. dJ_dw) assuming g is constant
        # compute g
        grad = 2*(w*x-y)*x # assumption, loss = (w*x - y)**2
        g = grad
        # compute w_new
        w_new = w - (g+w**2) * g
        # compute final loss J
        J = (w_new + x2 + y2)**2
        # computes derivative of J
        J.backward()
        #dw_new_dw = w_new.grad.item()
        dJ_dw = w.grad.item()
    else:
        raise ValueError(f'This: experiment_type={experiment_type}, is not a valid test')
    ##
    print('---- test results ----')
    print(f'experiment_type = {experiment_type}')
    print('-- Pytorch results')
    print(f'g = {g}')
    #print(f'dw_new_dw = {dw_new_dw}')
    print(f'dJ_dw = {dJ_dw}')
    print('-- Sympy results')
    g, dw_new_dw, dJ_dw = symbolic_test(experiment_type)
    print(f'g_SYMPY = {g}')
    #print(f'dw_new_dw_SYMPY = {dw_new_dw}')
    print(f'dJ_dw_SYMPY = {dJ_dw}')
if __name__ == '__main__':
    test_does_backward_backward_match_second_derivative(experiment_type='grad_constant')
    #test_does_backward_backward_match_second_derivative(experiment_type='grad_constant_but_use_backward_on_loss')
    #test_does_backward_backward_match_second_derivative(experiment_type='grad_analytic_only_backward_on_J')
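For checking the printed values by hand, here is the chain rule worked out for the script's toy values (w=2, x=3, y=4, x2=5, y2=6), with loss ℓ = (wx - y)² and g treated as a constant, as in the grad_constant case. The sympy module is not shown, so this assumes symbolic_test uses the same formulas.

$$
\begin{aligned}
g &= \frac{\partial \ell}{\partial w} = 2(wx - y)\,x = 2(2\cdot 3 - 4)\cdot 3 = 12,\\
w_{\text{new}} &= w - (g + w^2)\,g = 2 - (12 + 4)\cdot 12 = -190,\\
J &= (w_{\text{new}} + x_2 + y_2)^2 = (-190 + 5 + 6)^2 = 32041,\\
\left.\frac{dJ}{dw}\right|_{g\ \text{const}} &= 2(w_{\text{new}} + x_2 + y_2)\,(1 - 2wg) = 2(-179)(1 - 48) = 16826.
\end{aligned}
$$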
For "the gradient information w.grad is treated as a number in the computation graph and not as a variable that one can differentiate", you want to use w.grad.detach() in further computations to achieve that.
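A minimal sketch of that suggestion applied to the grad_constant experiment, using the script's toy values. The .clone() and the w.grad = None reset are assumptions added here, not part of the reply: .detach() alone still shares storage with w.grad, and clearing w.grad keeps the first dl/dw from being added to dJ/dw.

```python
import torch

w = torch.tensor([2.0], requires_grad=True)
x, y = torch.tensor([3.0]), torch.tensor([4.0])
x2, y2 = torch.tensor([5.0]), torch.tensor([6.0])

loss = (w * x - y) ** 2
loss.backward()
# detach() cuts g out of the graph; clone() gives it its own storage so
# that the next backward, which writes to w.grad, cannot change g.
g = w.grad.detach().clone()
w.grad = None  # drop the stale dl/dw before differentiating J

w_new = w - (g + w ** 2) * g
J = (w_new + x2 + y2) ** 2
J.backward()
dJ_dw = w.grad.item()
print(g.item(), dJ_dw)  # 12.0 16826.0
```

The result agrees with the hand-computed dJ/dw for a constant g, without cloning the parameter itself.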
.data will work but can be misleading, as it does not allow autograd to perform all of its correctness checks. This means you can introduce subtle bugs that make the gradients wrong. You cannot introduce such bugs with .detach() and torch.no_grad().
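To illustrate the kind of silent bug meant here, a small sketch: sigmoid saves its output for the backward pass, and that saved output is then modified in place through either .data or .detach(). The helper grad_after_inplace is hypothetical, written only for this comparison.

```python
import torch

def grad_after_inplace(use_data):
    # Hypothetical helper: modify a tensor saved for backward, then backprop.
    a = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
    out = a.sigmoid()  # sigmoid saves its output to compute its gradient
    c = out.data if use_data else out.detach()
    c.zero_()  # in-place change to the saved output
    out.sum().backward()
    return a.grad

# With .data autograd is not told about the change: no error, but the
# gradient out * (1 - out) is silently computed from the zeroed output.
print(grad_after_inplace(use_data=True))  # tensor([0., 0., 0.])

# With .detach() the change is tracked, so backward raises instead.
try:
    grad_after_inplace(use_data=False)
    raised = False
except RuntimeError:
    raised = True
print(raised)  # True
```

The toy experiments in this thread happen not to hit such a case, which is why the printed values still looked right.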
After some thought, I am confused by your suggestion, because I was thinking about how gradients work mathematically, and they are based on partial derivatives. Even if w.grad.detach() is called, when we take a partial derivative with respect to an independent variable (i.e. a leaf node in the computation graph), all other variables are constants. Thus I would have expected it to be irrelevant whether the gradient variable (i.e. w.grad) collects gradients or not.
But your comment seems to suggest otherwise. Why is that the case? What part of my reasoning is wrong?
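One way to see where "constant" comes from is that autograd decides it from the graph: a value is a constant for dJ/dw exactly when it has no path back to w. The sketch below compares g computed with create_graph=False (a plain value, like a detached one) against create_graph=True (g stays a function of w, i.e. differentiating through the backward pass, as the script's docstring asks about). It uses torch.autograd.grad instead of .backward() so nothing accumulates into w.grad; the helper dJ_dw is illustrative.

```python
import torch

def dJ_dw(differentiate_through_g):
    w = torch.tensor([2.0], requires_grad=True)
    x, y = torch.tensor([3.0]), torch.tensor([4.0])
    x2, y2 = torch.tensor([5.0]), torch.tensor([6.0])

    loss = (w * x - y) ** 2
    # create_graph=True records the backward pass itself, so g keeps a
    # path back to w; with False, g is a value with no history.
    (g,) = torch.autograd.grad(loss, w, create_graph=differentiate_through_g)

    w_new = w - (g + w ** 2) * g
    J = (w_new + x2 + y2) ** 2
    (dJ,) = torch.autograd.grad(J, w)
    return dJ.item()

print(dJ_dw(False))  # 16826.0: g treated as a constant
print(dJ_dw(True))   # 197258.0: adds the dg/dw = 2*x**2 = 18 terms
```

In the first case the result is the partial derivative with g held fixed, which is what treating w.grad as data is meant to give.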