The output of the network has a certain loss; however, I want to scale down the loss before backpropagating (for multi-task learning).

Is

```
loss = loss * weight # weight can be 0.5 for example
loss.backward()
```

the same as

`logits.data.backward(grad * weight)`

Yes, scaling the loss also scales the gradients, as this small example shows:

```
import torch
import torchvision.models as models

model = models.resnet152()
x = torch.randn(2, 3, 224, 224)
loss1 = model(x).mean()

# standard backward pass
loss1.backward(retain_graph=True)
grads1 = [p.grad.clone() for p in model.parameters()]

# scale the loss up by 2
model.zero_grad()
loss2 = loss1 * 2
loss2.backward(retain_graph=True)
grads2 = [p.grad.clone() for p in model.parameters()]

# scale the loss down by 0.2
model.zero_grad()
loss3 = loss1 * 0.2
loss3.backward(retain_graph=True)
grads3 = [p.grad.clone() for p in model.parameters()]

# compare: the gradient ratios match the loss scaling factors
for g1, g2, g3 in zip(grads1, grads2, grads3):
    print('g2.sum()/g1.sum() {}'.format(g2.sum() / g1.sum()))
    print('g3.sum()/g1.sum() {}'.format(g3.sum() / g1.sum()))
```
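Regarding the second form in the question: note that calling `backward` on `logits.data` would not work, since `.data` is detached from the autograd graph; the equivalent call is `backward(grad * weight)` on the attached tensor itself (or on the scalar loss). A minimal sketch of this equivalence, assuming a small hypothetical linear model rather than the original network:

```python
import torch

torch.manual_seed(0)
lin = torch.nn.Linear(4, 1)
x = torch.randn(2, 4)
weight = 0.5

# variant 1: scale the loss, then call backward()
loss = lin(x).mean() * weight
loss.backward()
grads_a = [p.grad.clone() for p in lin.parameters()]

# variant 2: pass the scaled upstream gradient to backward() directly
lin.zero_grad()
loss = lin(x).mean()
loss.backward(gradient=torch.tensor(weight))
grads_b = [p.grad.clone() for p in lin.parameters()]

for ga, gb in zip(grads_a, grads_b):
    print(torch.allclose(ga, gb))  # -> True for each parameter
```

Both variants seed the backward pass of the unscaled loss with `weight`, so the resulting parameter gradients are identical.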
