I am trying to create a batched version of a method I am writing: I want to compute the loss over the whole dataset and then optimize with respect to specific slices of it. An example of what I want to do:
```python
import torch as th
from torch.autograd import *

# dtype=th.float32 is needed, since requires_grad only works on float tensors
x = Variable(th.arange(4, dtype=th.float32), requires_grad=True)
loss = th.sum(th.max(x ** 3, th.zeros(4)))

# works: gradient with respect to all four elements
print(th.autograd.grad(loss, x, retain_graph=True))
# fails: x[:2] creates a new tensor that was never used in the graph
print(th.autograd.grad(loss, x[:2]))
```
where I wish the last print would give the derivative for the first two elements. So I need some way of slicing the data while preserving the graph. How can I do this without changing the way that I compute the loss?
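One direction I have considered (a sketch only, and it assumes a recent PyTorch where `Variable` has been merged into `Tensor`) is to build `x` by concatenating leaf tensors, so that each slice is its own node in the graph and can be differentiated against, while the loss itself is still computed over the full vector:

```python
import torch as th

# Hypothetical workaround: construct x from per-slice leaf tensors,
# so each slice participates in the graph.
x_head = th.arange(0, 2, dtype=th.float32, requires_grad=True)  # elements 0..1
x_tail = th.arange(2, 4, dtype=th.float32, requires_grad=True)  # elements 2..3
x = th.cat([x_head, x_tail])  # same values as th.arange(4)

# The loss is computed over the whole vector, exactly as before.
loss = th.sum(th.max(x ** 3, th.zeros(4)))

# Gradient with respect to the whole (non-leaf) vector ...
print(th.autograd.grad(loss, x, retain_graph=True))
# ... and with respect to just the first two elements.
print(th.autograd.grad(loss, x_head))
```

The drawback is that it changes how `x` is constructed up front, which may not fit my setting if the slices are only chosen after the loss is built.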