Backward step for non-scalar output


I want to compute the gradient with respect to a parameter, but separately for each batch element and each output dimension.

So assume my model's output y_pred has shape (b, m); then I want to compute y_pred[i,j].backward() for each i, j. I can do this naively by looping over i and j, but what is the most performant way to do it?
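
For concreteness, here is a minimal sketch of the naive loop I mean. The `Linear` model, the sizes, and the names `x`, `weight`, and `grads` are placeholders made up for illustration, not my real setup. The sketch uses `torch.autograd.grad` rather than `.backward()` so that the gradients for different (i, j) do not accumulate in `.grad`.

```python
import torch

torch.manual_seed(0)

# Illustrative sizes (assumed): batch size, input features, output dimensions.
b, n, m = 4, 3, 2
model = torch.nn.Linear(n, m)  # stand-in model, assumed for illustration
x = torch.randn(b, n)

y_pred = model(x)       # shape (b, m)
weight = model.weight   # the parameter of interest, shape (m, n)

# One backward pass per (i, j), so b * m passes in total.
grads = torch.zeros(b, m, *weight.shape)
for i in range(b):
    for j in range(m):
        # retain_graph=True keeps the graph alive for the next pass.
        (g,) = torch.autograd.grad(y_pred[i, j], weight, retain_graph=True)
        grads[i, j] = g

print(grads.shape)  # one gradient of weight's shape per (i, j)
```

This needs b * m separate backward passes through the same graph, which is the cost I would like to avoid.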

Cheers, Fabian!