I am trying to create a batched version of a method I am writing: I want to compute the loss over the whole dataset and then optimize for specific slices. Here is an example of what I want to do:

```
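import torch as th

# Variable is deprecated; tensors track gradients directly.
# x must be floating point to require gradients (arange defaults to int64).
x = th.arange(4, dtype=th.float32, requires_grad=True)
loss = th.sum(th.max(x ** 3, th.zeros(4)))  # elementwise max with 0, then sum
# retain_graph=True keeps the graph alive for the second grad call
print(th.autograd.grad(loss, x, retain_graph=True))
# Desired: the gradient for x[0] and x[1] only. This raises a RuntimeError,
# because x[:2] is a new tensor that was not used to compute loss.
print(th.autograd.grad(loss, x[:2]))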
```

I want the last print to give the derivative with respect to the first two elements only. So I need some way of slicing the data while preserving the graph. How can I do this without changing the way I compute the loss?
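A minimal sketch of two common workarounds, assuming a recent PyTorch where `Variable` is no longer needed. The helper `loss_fn` and the names `head` and `tail` are hypothetical, introduced only for illustration. The first approach differentiates with respect to the whole tensor and slices the resulting gradient; the second makes the slices leaf tensors up front and rebuilds the full input with `th.cat`, so the loss is still computed over the whole tensor.

```python
import torch as th


def loss_fn(x):
    # Hypothetical helper wrapping the question's loss: elementwise max with
    # zero, then sum (zeros_like(x) equals th.zeros(4) for a length-4 x).
    return th.sum(th.max(x ** 3, th.zeros_like(x)))


# Approach 1: differentiate w.r.t. the full tensor, then slice the gradient.
x = th.arange(4, dtype=th.float32, requires_grad=True)
(full_grad,) = th.autograd.grad(loss_fn(x), x)
grad_first_two = full_grad[:2]

# Approach 2: create the slices as leaf tensors, then concatenate them so the
# loss is computed on the whole tensor while each slice stays in the graph.
head = th.arange(2, dtype=th.float32, requires_grad=True)
tail = th.arange(2, 4, dtype=th.float32, requires_grad=True)
x_cat = th.cat([head, tail])
(grad_head,) = th.autograd.grad(loss_fn(x_cat), head)

print(grad_first_two)  # tensor([0., 3.])
print(grad_head)       # tensor([0., 3.])
```

Approach 1 still computes the gradient for every element and only discards the rest afterwards. Approach 2 requires splitting the data into leaf slices before building the loss, which may fit the batched setting better, since each slice can then be handed to its own optimizer.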