I want to compute the gradient with respect to a parameter, but separately for each batch element and each output dimension.
Suppose my model's output `y_pred` has shape `(b, m)`. Then I want to compute `y_pred[i, j].backward()` for every pair `i, j`. I can do this naively by looping over `i` and `j`, but what is the most performant way to do it?
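For reference, here is a minimal sketch of the loop described above next to two vectorized alternatives. It assumes PyTorch 2.x, where `torch.func.jacrev`, `torch.func.vmap` and `torch.func.functional_call` are available. The toy model, the input `x`, the sizes, and the helper names (`naive_jacobian`, `model_output`, `single_output`) are hypothetical. The per-sample variant further assumes that row `i` of `y_pred` depends only on row `i` of the input, so it would not hold for, e.g., BatchNorm in training mode.

```python
import torch
from torch import nn
from torch.func import functional_call, jacrev, vmap

torch.manual_seed(0)

# Hypothetical toy setup: b samples, d_in input features, m outputs.
b, d_in, m = 4, 3, 2
model = nn.Sequential(nn.Linear(d_in, 5), nn.Tanh(), nn.Linear(5, m))
x = torch.randn(b, d_in)

# Parameters as a plain dict so torch.func can differentiate with respect to them.
params = {name: p.detach() for name, p in model.named_parameters()}


def naive_jacobian():
    """Baseline from the question: one backward() call per entry y_pred[i, j]."""
    y_pred = model(x)
    out = {name: torch.zeros(b, m, *p.shape) for name, p in model.named_parameters()}
    for i in range(b):
        for j in range(m):
            model.zero_grad()  # backward() accumulates into .grad, so clear it first
            y_pred[i, j].backward(retain_graph=True)  # keep the graph for the next entry
            for name, p in model.named_parameters():
                out[name][i, j] = p.grad
    return out


def model_output(p):
    # Run the model on the whole batch using the parameters in `p`.
    return functional_call(model, p, (x,))


def single_output(p, xi):
    # Run the model on one sample (as a batch of size 1) and drop the batch dim.
    return functional_call(model, p, (xi.unsqueeze(0),)).squeeze(0)


naive = naive_jacobian()

# Option 1: full Jacobian of the (b, m) output w.r.t. every parameter.
# jac_full[name] has shape (b, m, *param.shape).
jac_full = jacrev(model_output)(params)

# Option 2: per-sample Jacobians, vectorized over the batch with vmap.
# Parameters are shared (in_dims=None); x is split along dim 0.
jac_per_sample = vmap(jacrev(single_output), in_dims=(None, 0))(params, x)

for name in params:
    print(name, tuple(jac_full[name].shape),
          torch.allclose(naive[name], jac_full[name], atol=1e-5),
          torch.allclose(naive[name], jac_per_sample[name], atol=1e-5))
```

Both vectorized versions produce one tensor per parameter, indexed as `[i, j, ...]`, matching what the loop stores from each `y_pred[i, j].backward()`. The per-sample variant pushes each of the `b * m` backward passes through a single sample instead of the whole batch, so it is typically the cheaper choice when its independence assumption holds.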