How to access CrossEntropyLoss() gradient?

I want to modify the tensor that stores the CrossEntropyLoss() gradient, i.e. the predicted probabilities minus the one-hot targets, P(i) - T(i). Where is it stored, and how do I access it?

code:

import torch
import torch.nn as nn

input = torch.randn(3, 5, requires_grad=True)
input.register_hook(lambda x: print(" \n input hook: ",x))
print(input)
target = torch.empty(3, dtype=torch.long).random_(5)
print(target)

criterion = nn.CrossEntropyLoss()
criterion.requires_grad = True  # note: this has no effect; nn.Module does not use this attribute
loss0 = criterion(input,target)
loss0.register_hook(lambda x: print(" \n loss0 hook: ",x))
print("before backward loss0.grad :",loss0.grad)
print("loss0 :",loss0)
loss0.backward()
print("after backward loss0.grad :",loss0.grad)

output:

tensor([[-0.6149, -0.8179,  0.6084, -0.2837, -0.5316],
        [ 1.7246,  0.5348,  1.3646, -0.7148, -0.3421],
        [-0.3478, -0.6732, -0.7610, -1.0381, -0.5570]], requires_grad=True)
tensor([4, 1, 0])
before backward loss0.grad : None
loss0 : tensor(1.7500, grad_fn=<NllLossBackward>)
 
 loss0 hook:  tensor(1.)
 
 input hook:  tensor([[ 0.0433,  0.0354,  0.1472,  0.0603, -0.2862],
        [ 0.1504, -0.2876,  0.1050,  0.0131,  0.0190],
        [-0.2432,  0.0651,  0.0597,  0.0452,  0.0732]])
after backward loss0.grad : None

The gradient you are looking for is input.grad.
What do you understand by loss.grad? input.grad is the gradient of the loss w.r.t. input, which is exactly the cross-entropy gradient P(i) - T(i). loss0.grad stays None because .grad is only populated on leaf tensors: loss0 is an intermediate result that is not used in any further operations, and autograd does not retain gradients for non-leaf tensors unless you call loss0.retain_grad().
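
You can check this numerically. A minimal sketch (with the default reduction='mean', the analytic gradient softmax(input) - one_hot(target) is additionally divided by the batch size):

import torch
import torch.nn.functional as F

input = torch.randn(3, 5, requires_grad=True)
target = torch.empty(3, dtype=torch.long).random_(5)

loss = F.cross_entropy(input, target)  # reduction='mean' by default
loss.backward()

# analytic cross-entropy gradient: (softmax(input) - one_hot(target)) / batch_size
probs = F.softmax(input, dim=1)
one_hot = F.one_hot(target, num_classes=5).float()
print(torch.allclose(input.grad, (probs - one_hot) / input.shape[0]))  # True

To see the gradient arriving at loss0 itself, you can turn loss0 into an intermediate node by multiplying it by 1: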

import torch

input = torch.randn(3, 5, requires_grad=True)
input.register_hook(lambda x: print(" \n input hook: ",x))
print(input)
target = torch.empty(3, dtype=torch.long).random_(5)
print(target)

criterion = torch.nn.CrossEntropyLoss()
criterion.requires_grad = True  # no effect, as noted above
loss0 = criterion(input,target)
loss0.register_hook(lambda x: print(" \n loss0 hook: ",x))
loss1 = 1*loss0  # identity op: loss0 is no longer the final output, so its hook fires with the incoming gradient
print("before backward loss0.grad :",loss0.grad)
print("loss0 :",loss0)
loss1.backward()
print("after backward loss0.grad :",loss0.grad)

What you can do is register a backward hook on the criterion module itself rather than inspecting tensors, or simply apply an identity op (multiplying by 1, as in the snippet above) to see the gradient of each intermediate term.
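
For the first option, here is a minimal sketch. It assumes a PyTorch version that provides nn.Module.register_full_backward_hook (the older register_backward_hook is similar); the hook receives the gradients flowing into and out of the module during backward():

import torch
import torch.nn as nn

def criterion_hook(module, grad_input, grad_output):
    # grad_input:  gradients w.r.t. the module's inputs (the logits; the long target gets no gradient)
    # grad_output: gradient w.r.t. the module's output (the scalar loss)
    print("grad_input :", grad_input)
    print("grad_output:", grad_output)

input = torch.randn(3, 5, requires_grad=True)
target = torch.empty(3, dtype=torch.long).random_(5)

criterion = nn.CrossEntropyLoss()
criterion.register_full_backward_hook(criterion_hook)

loss = criterion(input, target)
loss.backward()
# the gradient w.r.t. the logits reported by the hook is the same P(i) - T(i)
# (scaled by 1/batch_size) that ends up in input.grad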
