Why do we need to provide inputs in torch.autograd.grad


I am trying to understand how torch.autograd.grad works. My understanding of how PyTorch performs differentiation is that each tensor produced by an operation has a grad_fn that describes how it was derived from other tensors. If we want to write a custom function in PyTorch, we subclass torch.autograd.Function and implement the forward and backward static methods. During the forward pass, we use the ctx object (e.g. ctx.save_for_backward) to cache the inputs and any other information needed for the backward pass.
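The setup described above can be sketched as follows, assuming a hypothetical custom op MulAdd that computes w * x + b (it is not part of PyTorch). It shows the two different roles at play: ctx.save_for_backward caches the values that backward needs to compute its local derivatives, while the second argument to torch.autograd.grad selects which gradients you want returned.

```python
import torch


class MulAdd(torch.autograd.Function):
    """Hypothetical custom op: out = w * x + b (elementwise)."""

    @staticmethod
    def forward(ctx, w, x, b):
        # Cache what backward needs to compute local derivatives.
        ctx.save_for_backward(w, x)
        return w * x + b

    @staticmethod
    def backward(ctx, grad_out):
        w, x = ctx.saved_tensors
        # One gradient per forward input, in the same order: (w, x, b).
        return grad_out * x, grad_out * w, grad_out


w = torch.tensor([2.0, 3.0], requires_grad=True)
x = torch.tensor([5.0, 7.0], requires_grad=True)
b = torch.tensor([1.0, 1.0], requires_grad=True)

y = MulAdd.apply(w, x, b).sum()

# Same graph, same saved tensors, two different questions:
# d(y)/d(w) and d(y)/d(x). retain_graph keeps the graph for the second call.
(gw,) = torch.autograd.grad(y, w, retain_graph=True)
(gx,) = torch.autograd.grad(y, x)
print(gw)  # tensor([5., 7.])
print(gx)  # tensor([2., 3.])
```

The saved tensors say how to differentiate each step; they do not say which leaves or intermediates the caller is interested in.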

My question is: if we are already caching the inputs in ctx, why do we still need to provide inputs to torch.autograd.grad?

Thank you in advance for your help 🙂


How else would the grad() function know which variables you want the derivative with respect to? For example, say you have something like g(f(w*x + b)), where g() produces your output. Do you want the derivative of g with respect to f, w, x, or b? The grad function has no way to guess; the inputs argument tells it.
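A small sketch of the g(f(w*x + b)) example, assuming for illustration that f = tanh and g(t) = t². The same graph can answer different questions depending on what you pass as inputs, including an intermediate result such as h = f(w*x + b).

```python
import torch

# Illustrative choices (assumptions): f = tanh, g(t) = t ** 2.
w = torch.tensor(0.5, requires_grad=True)
x = torch.tensor(2.0, requires_grad=True)
b = torch.tensor(0.0, requires_grad=True)

z = w * x + b          # w*x + b = 1.0
h = torch.tanh(z)      # f(w*x + b), an intermediate (non-leaf) tensor
out = h ** 2           # g(f(w*x + b))

# Ask for derivatives w.r.t. two leaves and one intermediate in one call.
dw, db, dh = torch.autograd.grad(out, (w, b, h), retain_graph=True)

# A different question on the same graph: only d(out)/d(x).
(dx,) = torch.autograd.grad(out, x)
```

Analytically, d(out)/dh = 2h and d(out)/dz = 2h(1 - h²), so d(out)/dw = 2h(1 - h²)·x and d(out)/dx = 2h(1 - h²)·w.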

Thanks for your reply.

I guess I was stuck on the idea that we usually only want derivatives with respect to the learnable parameters, and did not consider that sometimes we might want derivatives with respect to intermediate results too. Thanks for the help 🙂 Another related question: when two input tensors have identical shapes, how does grad know which one we are specifying? Does it look at grad_fn to do that?

Yes, a popular example is saliency maps, where you compute the derivatives of a class score with respect to the input pixel values.
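A minimal saliency-map sketch, assuming a tiny stand-in classifier (a single linear layer on an 8×8 RGB image) rather than a real model. The point is that the input image, not the weights, is passed as the inputs argument, and torch.autograd.grad returns the gradient directly without filling any .grad attributes.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Stand-in classifier (assumption); any differentiable model works the same way.
model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 10))
model.eval()

# The "pixels" we differentiate with respect to.
image = torch.rand(1, 3, 8, 8, requires_grad=True)
logits = model(image)
target = logits.argmax(dim=1).item()

# Gradient of the predicted class score w.r.t. the input pixels.
(pixel_grad,) = torch.autograd.grad(logits[0, target], image)

# One common saliency map: max absolute gradient over the color channels.
saliency = pixel_grad.abs().amax(dim=1)  # shape (1, 8, 8)

# grad() returns gradients instead of accumulating them, so .grad stays unset.
print(model[1].weight.grad)  # None
```

For this purely linear stand-in, the pixel gradient is just the weight row of the target class reshaped to the image shape; with a real network it depends on the input.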

Can you explain your question a bit more? I don't think I understand it correctly. Do you mean something like y = f(w^T x), where w and x have the same shape? Since you pass the output and input tensors themselves, the function knows by reference (object identity) which tensor you mean; the shape plays no role. Also, the computation graph recorded during the forward pass tracks the order of operations and links each tensor to the node that produced it (its grad_fn, or a gradient-accumulator node for leaf tensors).
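A quick sketch of the same-shape case, assuming f = sigmoid for illustration. w and x have identical shapes, yet grad() returns the right gradient for each because it identifies them by the tensor objects you pass. A fresh tensor with the same values and shape is not part of y's graph, so it gets no gradient (None with allow_unused=True).

```python
import torch

w = torch.tensor([0.1, 0.2, 0.3], requires_grad=True)
x = torch.tensor([0.4, 0.5, 0.6], requires_grad=True)
y = torch.sigmoid(w @ x)  # y = f(w^T x), with f = sigmoid as an assumption

# Same shapes, but grad() matches inputs by identity, not by shape.
dw, dx = torch.autograd.grad(y, (w, x), retain_graph=True)
# dw == f'(w^T x) * x and dx == f'(w^T x) * w

# Equal values and shape, but a different tensor that y never used.
x_copy = x.detach().clone().requires_grad_()
(g,) = torch.autograd.grad(y, x_copy, allow_unused=True)
print(g)  # None
```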

Thanks 🙂