class F(torch.autograd.Function):
def forward(self, i):
self.save_for_backward(i)
return tr.FloatTensor([1])
def backward(self, c):
a = tr.zeros(10,10)
i = self.saved_tensors
a[i] = tr.ones(3,10) # Error!!!
return None
f = F()
f(tr.LongTensor([2,3,4])
But this gives the error
TypeError: indexing a tensor with an object of type torch.LongTensor. The only supported types are integers, slices, numpy scalars and torch.LongTensor or torch.ByteTensor as the only argument
This is a weird message, Since given i has type tuple not LongTensor
If i change my code like below
def backward(self, c):
a = tr.zeros(10,10)
i = self.saved_tensors[0]
a[i] = tr.ones(3,10)
return None