Exact meaning of grad_input and grad_output

It actually is a bit more complicated:

  • grad_output is the gradient of the loss w.r.t. the layer output. So if you have a layer l and do, say, y = l(x) ; loss = y.sum(); loss.backward(), you get the gradient of loss w.r.t. y.
  • grad_input contains the gradients w.r.t. the inputs of the last operation in the layer. This may not quite be what you expected… For linear layers this is fairly complete, as the last op is torch.addmm, which multiplies the input with the weight and adds the bias. For other layers (e.g. a Sequential) it will be the last op of the last sublayer, whose inputs may not be related to the Sequential’s own inputs at all. You can see which op will be used by looking at y.grad_fn (see the sketch below the list).

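Here is a minimal sketch of this (the layer sizes and the report helper are just illustrative): it registers the legacy register_backward_hook on a plain Linear and on a Sequential, prints the shapes that arrive in grad_input / grad_output, and prints the output’s grad_fn, i.e. the op the hook actually ends up attached to. Exact behaviour may differ in newer PyTorch versions, which add register_full_backward_hook and warn about the legacy hook.

```python
import torch
import torch.nn as nn

def report(name):
    # legacy hook signature: hook(module, grad_input, grad_output), both are tuples of Tensors
    def hook(module, grad_input, grad_output):
        print(name, "grad_output:", [None if g is None else tuple(g.shape) for g in grad_output])
        print(name, "grad_input: ", [None if g is None else tuple(g.shape) for g in grad_input])
    return hook

# Linear: the last op is addmm, so grad_input holds the gradients for its
# arguments (roughly: bias, input, transposed weight)
lin = nn.Linear(4, 3)
lin.register_backward_hook(report("Linear"))
x = torch.randn(2, 4, requires_grad=True)
y = lin(x)
print(y.grad_fn)            # the op the hook is attached to (typically AddmmBackward)
y.sum().backward()

# Sequential: the hook lands on the last op of the last sublayer, so grad_input
# is unrelated to the Sequential's own input
seq = nn.Sequential(nn.Linear(4, 5), nn.ReLU(), nn.Linear(5, 3))
seq.register_backward_hook(report("Sequential"))
z = seq(torch.randn(2, 4))
print(z.grad_fn)
z.sum().backward()
```

For the single Linear the grad_input entries line up with addmm’s arguments, which is why that case looks reasonably complete; for the Sequential they belong to the last sublayer’s addmm, which is exactly the mismatch described above.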
So to be honest, I don’t know what the exact use case for this would be, and I cannot comment on the design choice, but you can see how a module hook is turned into a hook on grad_fn in the source of torch/nn/modules/module.py.

Best regards

Thomas
