Sorry if this is a duplicate, but all of the tutorials I've found on (reverse-mode) AD deal with a single datapoint, and I'm curious how the issues that arise from minibatches are handled. For example, how does autograd know that a function that parallelizes over the minibatch elements has nonzero Jacobian entries only between matching minibatch indices? I can come up with fixes on my own, but I'm curious about how PyTorch goes about this.
I'm not sure I understand your question exactly. Are you talking about autograd acting on a composition of functions parallelized over the minibatch elements (i.e. most models that don't use BatchNorm), leading to a single scalar loss whose gradient you want to compute?
In that case, during the backward pass you do, mathematically, have block-diagonal Jacobians (the Jacobians of each layer's outputs with respect to its inputs). But autograd doesn't materialize those Jacobians: each step of the chain rule is computed as a vjp (vector-Jacobian product), without wasting memory. What matters is that the result of the vjp is correct, not that internally it actually multiplies a vector by a materialized Jacobian.
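To make this concrete, here is a minimal sketch, assuming a toy per-sample layer (the names `layer`, `W`, `v` and the sizes are hypothetical, chosen only for illustration). `torch.autograd.grad` with `grad_outputs` returns the vjp directly, which is what backward does. For comparison only, `torch.autograd.functional.jacobian` materializes the full Jacobian so you can check that it is block-diagonal across the batch and that multiplying by it gives the same answer.

```python
import torch

torch.manual_seed(0)

B, D_in, D_out = 4, 3, 2  # hypothetical batch size and feature sizes
W = torch.randn(D_out, D_in)

def layer(x):
    # Hypothetical layer applied independently to each row of the minibatch
    return torch.tanh(x @ W.T)

x = torch.randn(B, D_in, requires_grad=True)
y = layer(x)

# Upstream gradient dL/dy, e.g. as if loss = (v * y).sum()
v = torch.randn(B, D_out)

# What backward does: a vector-Jacobian product; no Jacobian is built
(vjp,) = torch.autograd.grad(y, x, grad_outputs=v)

# For comparison only: materialize the full Jacobian, shape (B, D_out, B, D_in)
J = torch.autograd.functional.jacobian(layer, x.detach())

# Off-diagonal blocks (i != j) are zero: sample i's output ignores sample j's input
off_diag = [J[i, :, j, :] for i in range(B) for j in range(B) if i != j]
assert all(torch.all(blk == 0) for blk in off_diag)

# Contracting v with the materialized Jacobian gives the same result as the vjp
explicit = torch.einsum("bo,boci->ci", v, J)
assert torch.allclose(vjp, explicit, atol=1e-6)
```

The explicit Jacobian here has B·D_out·B·D_in entries, most of them zero, while the vjp only ever produces a tensor shaped like `x`. That gap is why autograd never needs to "know" about the block-diagonal structure: each operation's backward formula already computes the product directly.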
> I can come up with fixes on my own
What kind of fixes are you talking about?