Is the gradient of the sum equal to the sum of the gradients for a neural network in PyTorch?

Let’s suppose I have the code below and I want to calculate the Jacobian of L, the prediction made by a neural network in PyTorch. L is of size n x 1, where n is the number of samples in a mini-batch. To avoid a for loop over the n entries of L (one Jacobian per sample in the mini-batch), some code I found simply sums the n predictions of the network (L) and then calculates the gradient of that sum with respect to the inputs. First, I can’t understand why the gradient of the sum is the same as the sum of the gradients for each sample in the PyTorch architecture. Second, I tried both the sum and a for loop and the results diverge. Could this be due to numerical approximation, or because the sum just doesn’t make sense?

The code is below, where both functions belong to an nn.Module:

def forward(self, x):
    # assumes: import torch; import torch.nn.functional as F; from torch.autograd import grad
    with torch.set_grad_enabled(True):
        def function(x, t):
            self.n = n = x.shape[1] // 2

            qqd = x.requires_grad_(True)
            # sum the per-sample predictions, then differentiate the scalar sum
            L = self._lagrangian(qqd).sum()
            J = grad(L, qqd, create_graph=True)[0]
            # (rest of forward omitted)

def _lagrangian(self, qqd):
    x = F.softplus(self.fc1(qqd))
    x = F.softplus(self.fc2(x))
    x = F.softplus(self.fc3(x))
    L = self.fc_last(x)
    return L

The other way would be to perform a for loop over each sample in the mini-batch and remove the sum when computing the Jacobian. However, the results differ slightly, by about 1e-5 to 1e-6 for the predictions. The differences are larger for the loss, which I suppose is due to error propagation in the MSE loss function.
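For context, this is roughly the comparison I am making (a minimal, hypothetical sketch with a toy model standing in for my network, not my actual code):

```python
import torch
import torch.nn as nn
from torch.autograd import grad

# Toy stand-in for the Lagrangian network (hypothetical, for illustration only).
model = nn.Sequential(nn.Linear(4, 16), nn.Softplus(), nn.Linear(16, 1))

qqd = torch.randn(8, 4, requires_grad=True)  # (n_samples, n_inputs)

# 1) Sum trick: one backward pass for the whole mini-batch.
L = model(qqd).sum()
J_sum = grad(L, qqd, create_graph=True)[0]   # shape (8, 4)

# 2) Explicit loop: one backward pass per sample.
rows = []
for i in range(qqd.shape[0]):
    L_i = model(qqd[i:i + 1]).sum()          # scalar prediction for sample i
    rows.append(grad(L_i, qqd, create_graph=True)[0][i])
J_loop = torch.stack(rows)                   # shape (8, 4)

# The two agree up to floating-point noise.
print(torch.allclose(J_sum, J_loop, atol=1e-6))
```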

Hi,

First, I can’t understand why the gradient of the sum is the same as the sum of the gradients for each sample in the PyTorch architecture.

This happens because of the linearity of differentiation:
d/dqqd (L0 + L1 + L2) = dL0/dqqd + dL1/dqqd + dL2/dqqd
But this assumes that L0, L1 and L2 are independent of each other. If, in your forward, the result for each sample actually depends on the results for the other samples, you cannot separate the equation this way, because L0 will depend on L1, etc.
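For example (a toy element-wise case, nothing to do with your specific network):

```python
import torch

x = torch.randn(5, requires_grad=True)
y = x ** 2                        # y[i] depends only on x[i]

# The gradient of the sum recovers every per-element derivative at once.
g = torch.autograd.grad(y.sum(), x)[0]
print(torch.allclose(g, 2 * x))   # d(x_i**2)/dx_i = 2 * x_i
```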

Also, errors of the order of 1e-5 are expected when doing this kind of thing; they are simply due to the non-associativity of floating-point operations: because you compute things in a different order, you get slightly different results. This is expected.
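You can already see the ordering effect with plain Python floats (toy numbers, just to illustrate):

```python
# The same three numbers, summed in a different order, differ in the last bits.
a, b, c = 0.1, 0.2, 0.3
print((a + b) + c == a + (b + c))    # False
print((a + b) + c - (a + (b + c)))   # about 1.1e-16 here; the same effect in float32 is around 1e-7
```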

Hi, huge thank you. Also, I’m trying to replicate code from the JAX framework (https://jax.readthedocs.io/en/latest/). However, despite having everything exactly the same, I get different results. Basically, I have the same weight initialization and the same hyperparameters. The code in both PyTorch and JAX makes use of the inverse or pseudo-inverse. I know there are other discussions about reproducibility, especially between PyTorch and TensorFlow, but do you have any idea whether it’s an intrinsic difference in the calculations between PyTorch and JAX, or whether it may be due to a bug in my code (I already verified everything)?

It depends on how different the results are.
If it is just small differences, then it is most likely numerical precision, because of floating-point operation ordering and different backends.
If it is bigger, you will have to reduce your example as much as possible to find where the problem is, or share it in a new topic to see if people can help you :slight_smile:


Hi. Just one more thing. Even if qqd has shape n_samples x n_inputs (to the neural network), does the gradient of the sum of L (which has shape n_samples) give the same result as a for loop over dL0/dqqd_sample0, dL1/dqqd_sample1, and so on? I assume in this case there is independence, but in the MSE loss the differences can be of the order of magnitude of 10^0, and I suppose the results diverge even more when n_samples (the mini-batch size) increases. Is this due to error propagation? Then of course, since there are so many little differences, the convergence will be different.

I’m asking this because, despite having everything the same, I don’t get the same results as in JAX. At the end of the first epochs the results are pretty similar, but after some epochs the results diverge, even though in PyTorch I get faster convergence. So there might be some error propagation due to the non-associativity of floating-point numbers. But shouldn’t it be present in JAX too? I mean, those very small values below 1e-5/1e-6 are noise, and this noise differs between operations, or in this case between frameworks? Also, in the JAX code they don’t perform the sum of L. However, using a for loop to calculate the Jacobian for each output of the neural network for each sample in the mini-batch would be ridiculously expensive computationally in PyTorch; that’s why I’m using the sum, since all outputs of the neural network (L1, L2, L3, ...) are independent of each other, as they depend on different samples.
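(For completeness: from what I’ve read, recent PyTorch versions also expose torch.func, which can compute per-sample gradients without a Python loop. A rough sketch under that assumption, again with a toy model in place of my network:)

```python
import torch
import torch.nn as nn
from torch.func import jacrev, vmap   # assumes a recent PyTorch with torch.func

# Toy stand-in for the network (hypothetical).
model = nn.Sequential(nn.Linear(4, 16), nn.Softplus(), nn.Linear(16, 1))

def scalar_out(sample):
    # sample: (n_inputs,) -> scalar prediction for one sample
    return model(sample).squeeze(-1)

qqd = torch.randn(8, 4)
# vmap maps the per-sample gradient over the batch dimension,
# so there is no explicit Python for loop.
J = vmap(jacrev(scalar_out))(qqd)     # shape (8, 4)
```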

You need the independence of the computation for this to work. If you have it, it will be fine.

in the MSE loss the differences can be of the order of magnitude of 10^0, and I suppose the results diverge even more when n_samples (the mini-batch size) increases. Is this due to error propagation?

Note that the 1e-6 error happens in one op. But then, as you keep computing, this error is usually amplified by every operation that you perform. So for a deep network, or after a few iterations, this error can grow quite a lot.
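As a rough illustration of how this amplification can look (a toy iterated map, not your training loop): starting from a perturbation at float-noise scale, repeated nonlinear operations can blow it up by many orders of magnitude.

```python
import torch

torch.manual_seed(0)
W = torch.randn(32, 32) * 2 / 32 ** 0.5   # a fixed random layer with gain > 1
x = torch.randn(32)
x_pert = x + 1e-6                         # difference at float-noise scale

# Push both copies through the same repeated nonlinear map and watch the gap grow.
for step in range(51):
    if step % 10 == 0:
        print(step, (x - x_pert).abs().max().item())
    x = torch.tanh(W @ x)
    x_pert = torch.tanh(W @ x_pert)
```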

At the end of the first epochs the results are pretty similar, but after some epochs the results diverge

This is totally expected when you use neural nets: they will end up converging to different places in the space of parameters.
But most neural network loss functions are well behaved enough that all of the optima that are found give similar loss values (even though you can have very different parameter values).
You can see this as having the same impact as sampling a different set of initial weights.

But shouldn’t it be present in JAX too?

This is present in JAX as well, so you will have the same behavior: very small differences in one op that grow over time.
