Loss computation in multi branch neural network in PyTorch

I am new to PyTorch and have difficulty understanding gradient computation in a multi-branch neural network. Could someone explain how the loss is backpropagated from the output back to the input? Any arbitrary example that makes the explanation easier would be welcome.


If by multi-branch network you mean something like an Inception module, consider the following example:

import torch
x = torch.ones(1, 2, requires_grad=True)  # 1
y = torch.cat([x, x * 2])  # 2
y.sum().backward()  # 3
x.grad  # tensor([[3., 3.]])  # 4

(4) By convention, the gradient of the loss w.r.t. itself (a scalar) is 1 (also a scalar).
(3) In SumBackward, this value is expanded to the shape of y: since loss = sum_i y_i, we have dloss/dy_i = 1 for every element y_i.
(2) Cat takes two inputs; call them a and b. The gradient w.r.t. cat's output has the same shape as that output, and the gradients w.r.t. a and b have the same shapes as a and b, so CatBackward simply slices the incoming gradient back into pieces matching a and b.
(1) Finally, since x is used twice in (2), by the chain rule the gradients from the two uses are summed. To see why, view a and b as functions of x and define y = cat(a(x), b(x)); then dloss/dx = (dloss/da)(da/dx) + (dloss/db)(db/dx) = 1·1 + 1·2 = 3 for each element of x.
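You can inspect each of these intermediate gradients yourself. A small sketch (the branch names `a` and `b` are just illustrative): `retain_grad()` keeps gradients on the non-leaf tensors so you can see what SumBackward and CatBackward produce at each step.

```python
import torch

x = torch.ones(1, 2, requires_grad=True)
a = x * 1        # branch "a" of the cat (identity, made non-leaf so we can inspect its grad)
b = x * 2        # branch "b" of the cat
a.retain_grad()  # keep gradients on non-leaf tensors for inspection
b.retain_grad()
y = torch.cat([a, b])
y.retain_grad()
y.sum().backward()

print(y.grad)  # step (3): SumBackward expands 1 to ones of shape (2, 2)
print(a.grad)  # step (2): CatBackward slices y.grad -> ones of shape (1, 2)
print(b.grad)  # step (2): ones of shape (1, 2)
print(x.grad)  # step (1): branch a contributes 1, branch b contributes 2 -> tensor([[3., 3.]])
```

Note that each branch's contribution (1 from `a`, 2 from `b`) is accumulated into `x.grad`, which is exactly the summation the chain rule prescribes when a tensor is used more than once.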