Difference between criterion.backward and tensor.backward

Hello all,

I can understand the use of loss.backward() with loss = BCELoss()(pred, target): it computes the gradient of the loss with respect to each leaf of the graph and stores it in that leaf's .grad, so that the optimizer's step() call can perform the descent.
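For reference, here is a minimal sketch of that standard pattern. The tiny model, the random data and the learning rate are made up for illustration; it checks that backward() fills the weight's .grad and that plain SGD then moves the weight by -lr * grad.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical tiny binary classifier and data, for illustration only
model = nn.Sequential(nn.Linear(4, 1), nn.Sigmoid())
criterion = nn.BCELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

x = torch.randn(8, 4)
target = torch.randint(0, 2, (8, 1)).float()

pred = model(x)
loss = criterion(pred, target)  # a 0-dim (scalar) tensor

optimizer.zero_grad()
loss.backward()  # fills .grad of every leaf that requires grad (weights and bias)

weight = model[0].weight
grad_before_step = weight.grad.clone()
old_weight = weight.detach().clone()

optimizer.step()  # vanilla SGD: w <- w - lr * w.grad

# The applied update matches -lr * grad (up to float rounding)
update = weight.detach() - old_weight
print(torch.allclose(update, -0.1 * grad_before_step, atol=1e-6))  # True
```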

But I don’t understand the use of tensor.backward in the following code:

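import torch

def generate_gradients(self, class_idx, layer_name):
    # "class" is a reserved word in Python, so the argument is renamed class_idx
    activation = self.intermediate_activations[layer_name]
    logit = self.output[:, class_idx]  # one logit per sample in the batch
    # Seed the backward pass with ones (same as logit.sum().backward(retain_graph=True))
    logit.backward(torch.ones_like(logit), retain_graph=True)
    # gradients = torch.autograd.grad(logit, activation, grad_outputs=torch.ones_like(logit), retain_graph=True)[0]
    # gradients = gradients.cpu().detach().numpy()
    # self.gradients is assumed to be filled by a backward hook registered on layer_name
    gradients = self.gradients.cpu().detach().numpy()
    return gradients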

I’m trying to reimplement the TCAV (Testing with Concept Activation Vectors) paper with PyTorch, and the objective is to compute the gradients of a class logit with respect to the activations at a given layer.
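A minimal sketch of one way to get those gradients without relying on a self.gradients attribute: capture the layer output with a forward hook, then call torch.autograd.grad on the summed logit. The model, the layer name "hidden" and the cav vector are hypothetical placeholders (the CAV is random here, not trained). Because each sample's logit depends only on its own activations, summing over the batch still yields per-sample gradients. The last step is the directional derivative that TCAV uses as its conceptual sensitivity.

```python
from collections import OrderedDict

import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical small network; "hidden" stands in for the TCAV bottleneck layer
model = nn.Sequential(OrderedDict([
    ("hidden", nn.Linear(10, 6)),
    ("relu", nn.ReLU()),
    ("head", nn.Linear(6, 3)),
]))

activations = {}

def save_activation(name):
    # Forward hook: keep the layer output, still attached to the graph
    def hook(module, inputs, output):
        activations[name] = output
    return hook

model.hidden.register_forward_hook(save_activation("hidden"))

x = torch.randn(5, 10)
output = model(x)             # shape (5, 3): one logit per class
class_idx = 1
logit = output[:, class_idx]  # shape (5,): not a scalar

# d(logit_i)/d(activation_i) for every sample i; sum() seeds each logit with 1
grads = torch.autograd.grad(logit.sum(), activations["hidden"])[0]
print(grads.shape)  # torch.Size([5, 6])

# Hypothetical concept activation vector; sensitivity = grad . cav per sample
cav = torch.randn(6)
sensitivity = grads @ cav
print(sensitivity.shape)  # torch.Size([5])
```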

Thank you very much


Writing logit.backward(torch.ones_like(logit), retain_graph=True) is equivalent to logit.sum().backward(retain_graph=True), in case that makes more sense to you.
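A quick sketch that checks this equivalence on a toy graph (the tensors w and x are made up for the example). Note that .backward() accumulates into .grad, so it has to be reset between the two runs.

```python
import torch

torch.manual_seed(0)
w = torch.randn(3, 2, requires_grad=True)  # leaf tensor
x = torch.randn(5, 3)
logit = (x @ w)[:, 1]                       # non-scalar: shape (5,)

# Version 1: seed the backward pass with a tensor of ones
logit.backward(torch.ones_like(logit), retain_graph=True)
grad_ones = w.grad.clone()

# Version 2: reduce to a scalar first; the implicit seed is then 1
w.grad = None                               # reset, since backward() accumulates
logit.sum().backward(retain_graph=True)
grad_sum = w.grad.clone()

print(torch.allclose(grad_ones, grad_sum))  # True
```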


When you call .backward() on a scalar tensor (like the loss returned by BCELoss), PyTorch implicitly uses 1 as the starting gradient of the backward pass. Here logit is not a scalar (it has one entry per sample), so you have to pass the starting gradient to .backward() yourself, in this case torch.ones_like(logit). The same argument also lets you start from something other than 1 if you need to. As albanD said, doing .sum().backward() gives the same result: torch.ones_like(logit) is exactly the derivative of the .sum() operation with respect to logit. So it would probably be cleaner and more readable to just call .sum() on the logits: logit.sum().backward(retain_graph=True).
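A small sketch of both points, using a made-up toy function y = x ** 2: calling .backward() with no argument on a non-scalar output raises a RuntimeError, and passing a gradient v computes the vector-Jacobian product (here v * 2x, since the Jacobian is diagonal).

```python
import torch

x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
y = x ** 2  # non-scalar output; dy_i/dx_i = 2 * x_i

# With no gradient argument, backward() only works on one-element outputs
raised = False
try:
    y.backward()
except RuntimeError as err:
    raised = True
    print("RuntimeError:", err)

# Passing a starting gradient v computes v^T J,
# which is v * 2x elementwise because the Jacobian is diagonal
v = torch.tensor([1.0, 0.5, 0.0])
y.backward(v)
print(x.grad)  # tensor([2., 2., 0.])
```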

I hope I was able to help you a bit. Sorry if something I said is wrong, but I think it should be alright.

Have a nice day 🙂

If your question is why you call .backward on a tensor at all, the answer is a bit simpler. Every tensor in the graph (that requires grad) can be the starting point of a backward pass. So when you call .backward on the result of, for example, BCELoss()(pred, target), you don't call it on the loss object (the criterion module) but on the tensor it returns, so you are always calling .backward on a tensor.
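A short sketch of that point, with made-up prediction and target values: the criterion is an nn.Module without a backward method, calling it returns a tensor, and .backward is called on that tensor. An intermediate tensor that is not a loss can start a backward pass too.

```python
import torch
import torch.nn as nn

criterion = nn.BCELoss()  # an nn.Module: it has no .backward() method
pred = torch.tensor([0.8, 0.3], requires_grad=True)  # made-up predictions
target = torch.tensor([1.0, 0.0])

loss = criterion(pred, target)  # calling the module returns a tensor
print(isinstance(loss, torch.Tensor))  # True
print(hasattr(criterion, "backward"))  # False

loss.backward()  # backward is called on the returned tensor
loss_grad = pred.grad.clone()
print(loss_grad)

# Any tensor in the graph can start a backward pass, not just a loss
pred.grad = None
scaled = pred * 3.0
scaled.sum().backward()
print(pred.grad)  # tensor([3., 3.])
```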