Grad returns repeating values for independent batch entries and won't compute the second-order derivative

So, I have this function:

def ode_residual(self,
                 time: th.DoubleTensor,
                 state: th.DoubleTensor) -> th.DoubleTensor:

    dtheta = th.autograd.grad(state.sum(), time, retain_graph=True)[0]
    d2theta = th.autograd.grad(dtheta.sum(), time, retain_graph=True)[0]

where time and 2 other tensors are concatenated as the input to an MLP. time has requires_grad set to True and the other 2 variables do not. The shape of time is (b, 1) and the other 2 tensors have shape (b, 2). The MLP outputs a tensor that is sliced into 3 tensors; state is one of these and has shape (b, 1), where b is the batch size.

class MLP(nn.Module):

    def __init__(self,
                 width: int, 
                 depth: int,
                 pretrained: bool = False,
                 dtype=th.double) -> None:
        super().__init__()

        self.layers = [nn.Linear(1+2+2, width, dtype=dtype)]
        self.layers.extend([nn.Linear(width, width, dtype=dtype) for _ in range(depth)])
        self.layers.append(nn.Linear(width, 1+1+2, dtype=dtype))

        self.layers = nn.Sequential(*self.layers)
    
    def output_to_states(self, output) -> Tuple[th.DoubleTensor, th.DoubleTensor, th.DoubleTensor]:

        """
        
        Returns:
            state
            tensor 2
            tensor 3
        """

        return output[:, 0:1], output[:, 1:2], output[:, 2:]

    def forward(self, 
                t: th.DoubleTensor, 
                x0: th.DoubleTensor, 
                x1: th.DoubleTensor) -> th.DoubleTensor:

        x = th.cat([t, x0, x1], dim=1)
        return self.layers(x)

First, in the code above, dtheta comes back as a tensor of shape (b, 1) in which all entries have the same value. Why is this? Shouldn't all batch entries be independent, so that even though we have a summation, the gradients are different for each batch element? Also, calculating d2theta results in the following error even though I set retain_graph to True. Any help in answering these questions would be appreciated.

File "c:\path\to\project\core\dynamics.py", line 93, in ode_residual
    d2theta = th.autograd.grad(dtheta.sum(), time, retain_graph=True)[0]
  File "C:\path\to\project\.venv\lib\site-packages\torch\autograd\__init__.py", line 303, in grad
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

Hi Josue!

MLP is basically a Sequential that consists solely of Linear layers with
no intervening non-linearities. It therefore collapses, in effect, into a
single Linear (1+2+2, 1+1+2) (with a bunch of redundant parameters),
regardless of the values of width and depth.

So the time component of each batch element gets multiplied by the
corresponding components of the collapsed Linear, those terms are
summed (and the bias and terms from the “other 2 variables” are added).
The gradient is just this single multiplicative factor (from the collapsed
Linear) and the other terms added on don’t affect the gradient. Because
there is no non-additive “mixing” (no intervening non-linear “activations”),
the “other 2 variables” (and the various additive terms) don’t affect time’s
gradient, nor does the value of time itself affect its gradient, so the time
gradient is the same for all of the batch elements, even when the batch
elements themselves differ.
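
Here is a minimal, self-contained sketch (an illustration of the collapse, not your actual model) that shows d(state)/d(time) coming out identical for every batch element when the stack contains only Linear layers:

import torch as th
import torch.nn as nn

th.manual_seed(0)
b, width, depth = 4, 8, 3

# a purely linear stack with the same structure as the posted MLP
layers = [nn.Linear(1 + 2 + 2, width, dtype=th.double)]
layers += [nn.Linear(width, width, dtype=th.double) for _ in range(depth)]
layers += [nn.Linear(width, 1 + 1 + 2, dtype=th.double)]
net = nn.Sequential(*layers)

time = th.rand(b, 1, dtype=th.double, requires_grad=True)
x0 = th.rand(b, 2, dtype=th.double)
x1 = th.rand(b, 2, dtype=th.double)

out = net(th.cat([time, x0, x1], dim=1))
state = out[:, 0:1]

dtheta = th.autograd.grad(state.sum(), time)[0]
print(dtheta)  # every row is the same value: the collapsed Linear's time weight

Inserting a non-linearity (e.g. th.tanh) between the Linear layers makes dtheta vary across the batch.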

In the dtheta line you should be using create_graph = True rather than
retain_graph = True. This causes the first call to autograd.grad() to
create the graph that the second call to autograd.grad() will then use to
compute the second derivative. (Note that create_graph = True automatically
turns on retain_graph = True, so the newly-created graph will be preserved
for use by the second call to autograd.grad().)

(Your posted code, in and of itself, does not need retain_graph = True in its
second call to autograd.grad().)

Best.

K. Frank


Thank you so much! I knew about the problem of having only linear terms, but for some reason I didn't see it haha.

I have other calls to grad() outside of what I showed.