I'm working on something similar: replacing some PyTorch methods/functions, doing some tracing, then restoring the originals.
This works for most functions and methods, but not for `torch.Tensor.__getitem__`.
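For reference, the general replace, trace, restore pattern looks roughly like this. It is a minimal sketch with hypothetical names (`patched`, `make_tracer`, `Vec`), shown on a plain Python class instead of `torch.Tensor` so it runs on its own. `patched` saves the raw class-dict entry, installs the replacement, and on exit either puts the saved entry back or deletes the override if the attribute was only inherited.

```python
import contextlib

@contextlib.contextmanager
def patched(owner, name, replacement):
    """Temporarily replace owner.<name>; restore the previous class dict state on exit."""
    own = vars(owner)
    had_own = name in own      # was the attribute defined on this class itself?
    saved = own.get(name)      # raw class-dict entry (keeps descriptors intact)
    setattr(owner, name, replacement)
    try:
        yield
    finally:
        if had_own:
            setattr(owner, name, saved)  # put the original entry back
        else:
            delattr(owner, name)         # attribute was inherited: drop the override

def make_tracer(original, log):
    """Build a hypothetical tracing __getitem__ that logs keys and forwards to `original`."""
    def tracing_getitem(self, key):
        log.append(key)             # the "tracing" step
        return original(self, key)  # delegate to the saved implementation
    return tracing_getitem

class Vec:
    """Stand-in container used instead of torch.Tensor (assumption for illustration)."""
    def __init__(self, data):
        self.data = list(data)

    def __getitem__(self, key):
        return self.data[key]

v = Vec([10, 20, 30])
log = []
with patched(Vec, "__getitem__", make_tracer(Vec.__getitem__, log)):
    v[1]
    v[2]

print(log)   # keys seen while patched
print(v[0])  # original behaviour after restoring; not logged
```

With ordinary Python methods this round trip leaves the class as it was; `torch.Tensor.__getitem__` is the case where it does not.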
Here's the example code:
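```python
import torch

a = torch.randint(1, 10, (15,))
b = torch.randint(1, 10, (35,))
print(a[b])  # OK
print(b[a])  # OK

# Save the original method and install a tracing replacement
getitem = torch.Tensor.__getitem__

def new_getitem(self, key):
    print(key)                 # trace the index
    return getitem(self, key)  # forward to the original implementation

torch.Tensor.__getitem__ = new_getitem
# ... tracing would run here ...
torch.Tensor.__getitem__ = getitem  # restore the original

print(a[b])  # OK
print(b[a])  # IndexError: too many indices for tensor of dimension 1
```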
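After restoring the original, `__getitem__` behaves differently: indexing with the larger tensor `b` (35 elements) still works, but indexing with the smaller tensor `a` (15 elements) now raises the IndexError above. Why does this happen, and how do I fix it?
Thanks!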