Computing output gradients for NTK

You can either reduce the outputs to a scalar value and call .backward() on it, or pass the output gradients directly to the backward call (the gradient argument of .backward()).
@albanD explains the reasoning behind this in this post.
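Here is a minimal sketch of both options. It assumes a toy nn.Linear model and a random batch, both chosen only for illustration. Summing the outputs is equivalent to passing a gradient of all ones, because d(sum)/d(out) is all ones. Passing a one-hot gradient instead gives the gradient of a single output element, which is the per-output gradient an NTK computation needs.

```python
import torch

torch.manual_seed(0)

# Hypothetical toy model and batch, chosen only for illustration
model = torch.nn.Linear(3, 2)
x = torch.randn(4, 3)

# Option 1: reduce the outputs to a scalar, then call .backward()
model.zero_grad()
model(x).sum().backward()
grad_from_sum = model.weight.grad.clone()

# Option 2: pass the upstream gradient (same shape as the output) to .backward()
model.zero_grad()
out = model(x)
out.backward(gradient=torch.ones_like(out))
grad_from_vector = model.weight.grad.clone()

# d(sum)/d(out) is all ones, so both options give the same parameter gradient
print(torch.allclose(grad_from_sum, grad_from_vector))  # True

# A one-hot gradient selects a single output element (here out[0, 1]),
# so .backward() yields the gradient of that one output w.r.t. the parameters
model.zero_grad()
out = model(x)
one_hot = torch.zeros_like(out)
one_hot[0, 1] = 1.0
out.backward(gradient=one_hot)
grad_single = model.weight.grad.clone()
```

For an NTK, you would repeat the one-hot pass for each output element (or use a Jacobian utility) and collect the resulting gradients.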
