How does the backward pass get plumbed in, and other rudimentary loss function questions


I’ve implemented the Dice IoU function as a loss function, as suggested by the V-Net paper. I copied the last couple of lines of its backward function from the backward generated by _make_function_class_criterion, and I implemented a separate loss function as a wrapper around the call to forward.
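The poster's code isn't reproduced here, so the following is only a minimal sketch of a soft Dice loss written as a custom torch.autograd.Function. It is not the poster's implementation. It uses the current static-method API (forward/backward receive a ctx object and are invoked via .apply), which differs from the older instance-method API the thread refers to. The backward uses the Dice gradient given in the V-Net paper for D = 2Σpg / (Σp² + Σg²). The names DiceLossFunction and EPS (a smoothing constant) are hypothetical. torch.autograd.gradcheck then compares the hand-written backward against finite differences.

```python
import torch

EPS = 1e-6  # hypothetical smoothing constant to avoid division by zero


class DiceLossFunction(torch.autograd.Function):
    """Soft Dice loss: 1 - D, where D = 2*sum(p*g) / (sum(p^2) + sum(g^2))."""

    @staticmethod
    def forward(ctx, pred, target):
        # pred: predicted foreground probabilities; target: binary mask, same shape
        ctx.save_for_backward(pred, target)
        intersection = (pred * target).sum()
        denom = (pred * pred).sum() + (target * target).sum() + EPS
        return 1 - 2 * intersection / denom

    @staticmethod
    def backward(ctx, grad_output):
        pred, target = ctx.saved_tensors
        intersection = (pred * target).sum()
        denom = (pred * pred).sum() + (target * target).sum() + EPS
        # dD/dp_j = 2 * (g_j * denom - 2 * p_j * intersection) / denom^2
        grad_dice = 2 * (target * denom - 2 * pred * intersection) / denom ** 2
        # loss = 1 - D, so negate; then scale by the upstream gradient (chain rule)
        grad_pred = -grad_dice * grad_output
        # No gradient is needed for the target mask
        return grad_pred, None


torch.manual_seed(0)
pred = torch.rand(2, 1, 4, 4, dtype=torch.double, requires_grad=True)
target = (torch.rand(2, 1, 4, 4, dtype=torch.double) > 0.5).double()

loss = DiceLossFunction.apply(pred, target)
loss.backward()

# gradcheck compares the hand-written backward against finite differences
print(torch.autograd.gradcheck(DiceLossFunction.apply, (pred, target)))  # True
```

In practice you could also write only the forward computation with ordinary tensor operations and let autograd derive the backward. A hand-written backward is mainly worth it for speed or numerical stability.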

Since I’m just imitating what I see elsewhere, there are a few things I don’t understand about loss functions.

In class Function you set __call__ = _C._FunctionBase._do_forward; presumably that just calls down into whatever forward function a derived class implements. Given a call to the forward function in the training code that returns a tensor, how does torch know what the corresponding backward function is when you call backward() on that tensor?
For example:

```python
loss = F.nll_loss(output, target)
loss.backward()
```
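In current PyTorch (the internals have changed since this thread, and Variable has since been merged into Tensor), you can inspect this link directly. A tensor produced by a differentiable operation carries a grad_fn attribute that points to the backward node that created it. That node's next_functions point to the nodes of its inputs. The toy shapes and the names logits, output and target below are hypothetical, chosen only to make the snippet runnable.

```python
import torch
import torch.nn.functional as F

# Toy batch: 3 samples, 5 classes (shapes chosen only for illustration)
logits = torch.randn(3, 5, requires_grad=True)
output = F.log_softmax(logits, dim=1)
target = torch.tensor([1, 0, 4])

loss = F.nll_loss(output, target)

# The returned tensor remembers which backward node produced it; the
# node's class name (e.g. NllLossBackward0) varies between versions.
print(loss.grad_fn)
# (node, index) pairs for the inputs; the LogSoftmax backward node is among them
print(loss.grad_fn.next_functions)

loss.backward()  # walks the graph starting from loss.grad_fn
print(logits.grad.shape)  # torch.Size([3, 5])
```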

My second question is what the last two lines do: what does the view(*repeat(… call do, and what does mul_(tensor.expand_as(…) do?

My last question: does the DiceLoss class make sense? When I first started looking at adding a new loss function a few days ago it seemed overwhelming, but this seems like a rather trivial piece of code.

Thanks in advance.

Yes. Function’s __call__ prepares the object, calls the forward method (implemented by the user), and does some postprocessing of the results (e.g. wrapping them in Variables). The graph is built during these steps. The function object is saved in it, so it can be found during backward and asked to compute the derivative.
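The mechanism described above can be observed with a small custom Function. This is a sketch using the current API, where forward and backward are static methods called through .apply rather than through an instance's __call__. The ctx object that forward fills in is stored in the graph node and handed back to backward. The class Square and its calls counter are hypothetical, added only to show that autograd does invoke the saved backward.

```python
import torch


class Square(torch.autograd.Function):
    """y = x**2 with a hand-written backward (illustrative example)."""

    calls = 0  # counts backward invocations, for demonstration only

    @staticmethod
    def forward(ctx, x):
        # Save what backward will need; ctx lives in the recorded graph node
        ctx.save_for_backward(x)
        return x * x

    @staticmethod
    def backward(ctx, grad_output):
        Square.calls += 1
        (x,) = ctx.saved_tensors
        # Chain rule: dL/dx = dL/dy * dy/dx = grad_output * 2x
        return grad_output * 2 * x


x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
y = Square.apply(x)  # forward runs now; the node is recorded as y.grad_fn
loss = y.sum()
loss.backward()      # autograd reaches y.grad_fn and calls Square.backward

print(x.grad)        # tensor([2., 4., 6.])
print(Square.calls)  # 1
```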

*repeat(...) is just Python syntax, nothing specific to PyTorch. repeat comes from the itertools module (see its docs): repeat(x, n) yields x n times, and the leading * unpacks those values into separate positional arguments to view.

expand_as is needed to make sure the multiplied tensors have the same size. It makes the tensor it is called on look like a bigger tensor using stride tricks, so no memory is copied. Note that you can only expand along new dimensions at the front (pretending the tensor has more dimensions than it really has) or along dimensions of size 1.
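Since the poster's exact lines are elided, the following is a hedged illustration of the pattern the question describes, not the generated criterion code itself. A scalar upstream gradient is reshaped to all-1 dimensions, expanded to the input gradient's shape, and multiplied in place. The tensor names, shapes and the 0.5 value are hypothetical.

```python
import torch
from itertools import repeat

# repeat(1, n) yields 1 n times; * unpacks it into positional arguments,
# so view(*repeat(1, 3)) is the same call as view(1, 1, 1).
print(list(repeat(1, 3)))  # [1, 1, 1]

grad_input = torch.ones(2, 3, 4)  # hypothetical local gradient w.r.t. the input
grad_output = torch.tensor(0.5)   # hypothetical upstream gradient of a scalar loss

# Reshape the scalar so it has as many dims as grad_input, all of size 1
g = grad_output.view(*repeat(1, grad_input.dim()))
print(g.shape)  # torch.Size([1, 1, 1])

# expand_as stretches the size-1 dims to (2, 3, 4) without copying:
# the expanded dims get stride 0, so every element reads the same memory.
g_expanded = g.expand_as(grad_input)
print(g_expanded.shape, g_expanded.stride())  # torch.Size([2, 3, 4]) (0, 0, 0)

# In-place multiply: scale the local gradient by the upstream gradient
grad_input.mul_(g_expanded)
print(grad_input[0, 0, 0].item())  # 0.5

# Expanding works for new leading dims or size-1 dims...
print(torch.ones(3).expand(2, 3).shape)  # torch.Size([2, 3])
# ...but not for a dim whose size is neither 1 nor the target size
try:
    torch.ones(3).expand(3, 2)
except RuntimeError as err:
    print("RuntimeError:", err)
```

Elementwise operations in current PyTorch broadcast automatically, so grad_input.mul_(grad_output) would work here without the explicit view and expand_as. The explicit form reflects the older API discussed in this thread.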