I would like to calculate the gradient on a sub-matrix. The following snippet demonstrates my use case:

```
import torch
import torch.nn as nn

A = torch.rand((3, 5), requires_grad=True)
target = torch.zeros((3, 5), requires_grad=True)
loss = nn.functional.mse_loss(A, target)

# Gradient w.r.t. the full matrix works (retain_graph so the graph
# is still available for the second call below)
grad = torch.autograd.grad(loss, A, retain_graph=True)[0]
print(f'full grad shape = {grad.shape}')

# Fails with a RuntimeError: A[:1, :] is a new tensor produced by the
# slicing op, so it does not appear as an input in the computation graph
grad = torch.autograd.grad(loss, A[:1, :])[0]
print(f'grad shape = {grad.shape}')
```

I can successfully calculate the gradient of the loss w.r.t. the matrix `A`. However, the same call fails for the sub-matrix `A[:1, :]`. How can I obtain the gradient restricted to a sub-matrix?
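For context, the only workaround I can sketch is slicing the full gradient after the fact. Since `A` is a leaf tensor, the gradient w.r.t. any sub-matrix of `A` is just the corresponding block of the full gradient (this assumes no cheaper way to compute only that block exists):

```
import torch
import torch.nn as nn

A = torch.rand((3, 5), requires_grad=True)
target = torch.zeros((3, 5))
loss = nn.functional.mse_loss(A, target)

# Compute the full gradient, then slice out the block of interest
full_grad = torch.autograd.grad(loss, A)[0]
sub_grad = full_grad[:1, :]
print(f'sub grad shape = {sub_grad.shape}')
```

This gives the right values, but it still computes the gradient for all of `A`, which is what I was hoping to avoid.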