I am new to PyTorch and have difficulty understanding gradient computation in a multi-branch neural network. Could someone please explain how the loss is backpropagated from the output back to the inputs? Any simple example that makes the explanation easier would be welcome.
If by multi-branch network you mean something like an Inception module, consider the following example:
import torch

x = torch.ones(1, 2, requires_grad=True)  # 1
y = torch.cat([x, x * 2])                 # 2
y.sum().backward()                        # 3
x.grad                                    # tensor([[3., 3.]]) # 4
(4) By convention, the gradient of the loss with respect to itself (a scalar) is 1 (also a scalar).
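This convention can be checked by passing the seed gradient explicitly; calling `loss.backward()` on a scalar is equivalent to seeding the backward pass with 1 (a minimal sketch with made-up numbers):

```python
import torch

x = torch.tensor([1., 2.], requires_grad=True)
loss = (3 * x).sum()
# passing the seed dloss/dloss = 1 explicitly; loss.backward() alone does the same
loss.backward(torch.tensor(1.))
print(x.grad)  # tensor([3., 3.])
```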
(3) In SumBackward, this value is expanded to the size of y, because each term y_i in y satisfies dloss/dy_i = dloss/dloss = 1.
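As a quick sketch of that expansion, `torch.autograd.grad` shows that the gradient of a sum with respect to its input is a tensor of ones the same shape as the input:

```python
import torch

y = torch.randn(2, 2, requires_grad=True)
loss = y.sum()
# dloss/dy: the scalar seed 1 expanded to y's shape
(grad_y,) = torch.autograd.grad(loss, y)
print(grad_y)  # tensor([[1., 1.], [1., 1.]])
```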
(2) cat takes two inputs; for now let's call them a and b. The gradient with respect to cat's output has the same shape as that output, and the gradients with respect to a and b have the same shapes as a and b, so the natural definition of CatBackward is to split the incoming gradient back into pieces matching the inputs.
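To see CatBackward in isolation, one can cat two independent leaves and backpropagate a hand-picked gradient (the values here are arbitrary, chosen only to make the slicing visible):

```python
import torch

a = torch.zeros(1, 2, requires_grad=True)
b = torch.zeros(1, 2, requires_grad=True)
y = torch.cat([a, b])                   # shape (2, 2)
g = torch.tensor([[1., 2.], [3., 4.]])  # gradient wrt y
y.backward(g)                           # CatBackward slices g back to the inputs
print(a.grad)  # tensor([[1., 2.]])
print(b.grad)  # tensor([[3., 4.]])
```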
(1) Finally, since x is used twice in (2), the chain rule says the gradients with respect to x from each use must be summed. To see why the chain rule gives this, treat a and b as functions of x and write y = cat(a(x), b(x)); then dloss/dx = (dloss/da)(da/dx) + (dloss/db)(db/dx).
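The summing of the two contributions can be checked on the original example by naming the branches explicitly and querying each intermediate gradient with `torch.autograd.grad`:

```python
import torch

x = torch.ones(1, 2, requires_grad=True)
a = x * 1.0   # first branch, da/dx = 1
b = x * 2.0   # second branch, db/dx = 2
loss = torch.cat([a, b]).sum()
(ga,) = torch.autograd.grad(loss, a, retain_graph=True)  # dloss/da = 1
(gb,) = torch.autograd.grad(loss, b, retain_graph=True)  # dloss/db = 1
# chain rule: dloss/dx = ga * da/dx + gb * db/dx = 1*1 + 1*2 = 3
(gx,) = torch.autograd.grad(loss, x)
print(gx)  # tensor([[3., 3.]])
```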