Consider the example below. I'm trying to understand why the two `jacrev` calls don't give the same result: the second call returns a tensor of zeros instead of the gradients that the first call produces. I feel like I'm missing something very fundamental here. Thanks!
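```python
import torch
from torch.func import jacrev

def comp(x, y):
    return x * x * y

def comp2(x, y):
    # Sums comp(x, y), which depends on x directly
    return torch.sum(comp(x, y))

def comp3(x, y):
    # Only reads y; x is passed in but never used
    return torch.sum(y)

R = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
z = torch.tensor([5.0])

print(jacrev(comp2)(R, z))           # gradients w.r.t. R
print(jacrev(comp3)(R, comp(R, z)))  # returns zeros
```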
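A minimal sketch of what is going on, assuming the goal is the gradient of `sum(comp(x, y))` with respect to `x`. `torch.func.jacrev` differentiates only with respect to the arguments selected by `argnums` (default `0`, i.e. `x`), and only through operations performed inside the function it wraps. A tensor computed beforehand, such as `comp(R, z)`, enters `comp3` as an ordinary input, so its dependence on `R` is not seen. The helper name `comp3_of_comp` is hypothetical, introduced only for illustration.

```python
import torch
from torch.func import jacrev

def comp(x, y):
    return x * x * y

def comp3(x, y):
    return torch.sum(y)

R = torch.tensor([1.0, 2.0, 3.0])
z = torch.tensor([5.0])

# comp3 never reads x, and y arrives already computed,
# so the Jacobian w.r.t. argument 0 is all zeros.
print(jacrev(comp3)(R, comp(R, z)))  # tensor([0., 0., 0.])

# Differentiating w.r.t. the second argument gives d(sum(y))/dy = ones.
print(jacrev(comp3, argnums=1)(R, comp(R, z)))  # tensor([1., 1., 1.])

# Hypothetical helper: the dependence on x now happens inside the
# transformed function, so jacrev can trace it (d/dx sum(5 * x^2) = 10 * x).
def comp3_of_comp(x, y):
    return comp3(x, comp(x, y))

print(jacrev(comp3_of_comp)(R, z))  # tensor([10., 20., 30.])

# For contrast, classic torch.autograd records how a value was built from a
# requires_grad tensor, so it does follow the sum back to R_ag.
R_ag = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
(g,) = torch.autograd.grad(torch.sum(comp(R_ag, z)), R_ag)
print(g)  # tensor([10., 20., 30.])
```

In other words, `torch.func` transforms act on functions rather than on the autograd graph attached to their inputs, which is why composing `comp` inside the differentiated function reproduces the `comp2` result.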