So, I have this function:
```python
def ode_residual(self, time: th.DoubleTensor, state: th.DoubleTensor) -> th.DoubleTensor:
    # First derivative of state w.r.t. time (grad returns a tuple, so unpack it)
    dtheta = th.autograd.grad(state.sum(), time, retain_graph=True)[0]
    # Second derivative of state w.r.t. time (this is the line that fails)
    d2theta = th.autograd.grad(dtheta.sum(), time, retain_graph=True)
```
The `time` tensor and two other tensors are concatenated to form the input to an MLP. `time` has `requires_grad` set to `True`; the other two tensors do not. `time` has shape `(b, 1)` and the other two tensors have shape `(b, 2)`. The MLP outputs a tensor that is sliced into three tensors; `state` is one of them and has shape `(b, 1)`, where `b` is the batch size.
```python
from typing import Tuple

import torch as th
from torch import nn


class MLP(nn.Module):
    def __init__(self, width: int, depth: int, pretrained: bool = False, dtype=th.double) -> None:
        super().__init__()
        self.layers = [nn.Linear(1 + 2 + 2, width, dtype=dtype)]
        self.layers.extend([nn.Linear(width, width, dtype=dtype) for _ in range(depth)])
        self.layers.append(nn.Linear(width, 1 + 1 + 2, dtype=dtype))
        self.layers = nn.Sequential(*self.layers)

    def output_to_states(self, output) -> Tuple[th.DoubleTensor, th.DoubleTensor, th.DoubleTensor]:
        """
        Returns:
            state
            tensor 2
            tensor 3
        """
        return output[:, 0:1], output[:, 1:2], output[:, 2:]

    def forward(self, t: th.DoubleTensor, x0: th.DoubleTensor, x1: th.DoubleTensor) -> th.DoubleTensor:
        x = th.cat([t, x0, x1], dim=1)
        return self.layers(x)
```
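For reference, here is a minimal sketch of how the pieces fit together; the batch size, the `width`/`depth` values, and the random inputs are placeholders rather than my real data, and the two `grad` calls mirror the body of `ode_residual`:

```python
import torch as th

b = 4  # placeholder batch size
model = MLP(width=32, depth=2)  # placeholder sizes; MLP is the class above

# Only time requires gradients; the other two inputs do not.
time = th.rand(b, 1, dtype=th.double, requires_grad=True)
x0 = th.rand(b, 2, dtype=th.double)
x1 = th.rand(b, 2, dtype=th.double)

output = model(time, x0, x1)                  # shape (b, 4)
state, _, _ = model.output_to_states(output)  # state: shape (b, 1)

# These two calls mirror ode_residual:
dtheta = th.autograd.grad(state.sum(), time, retain_graph=True)[0]  # shape (b, 1)
d2theta = th.autograd.grad(dtheta.sum(), time, retain_graph=True)   # raises RuntimeError
```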
First, in the code above, `dtheta` comes back as a tensor of shape `(b, 1)` in which every entry has the same value. Why is this? Shouldn't all batch entries be independent, so that even though we take a summation, the gradients differ across the batch? Also, computing `d2theta` results in the following error even though I set `retain_graph` to `True`. Any help answering these questions would be appreciated.
File "c:\path\to\project\core\dynamics.py", line 93, in ode_residual d2theta = th.autograd.grad(dtheta.sum(), time, retain_graph=True) File "C:\path\to\project\.venv\lib\site-packages\torch\autograd\__init__.py", line 303, in grad return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn