Tensor.__getitem__ error after restoring the original method

I'm working on something that replaces some PyTorch methods/functions, does some tracing, and then restores the originals.

Most other functions/methods work fine, except torch.Tensor.__getitem__.
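The replace/trace/restore pattern described above can be sketched generically as a context manager. This is a minimal sketch on plain Python classes, not PyTorch: the helper name `patched` and the classes `Base`/`Child` are hypothetical. The design choice it illustrates is that reading `Child.__getitem__` and assigning it back is not a no-op when the method was inherited, because the assignment adds a new entry to `Child.__dict__`. The helper therefore records whether the attribute was defined on the class itself and, if not, removes the patch with `delattr` instead. For `torch.Tensor`, Python also updates internal type slots whenever a special method such as `__getitem__` is set or deleted on the class, so whether `delattr` fully undoes the side effect there is an assumption, not something verified here.

```python
import contextlib

_MISSING = object()  # sentinel: the attribute was not defined on the class itself


@contextlib.contextmanager
def patched(cls, name, replacement):
    """Hypothetical helper: temporarily replace cls.<name>, then restore cls.__dict__.

    If `name` was inherited (absent from cls.__dict__), the patch is removed
    with delattr rather than by re-assigning the inherited value.
    """
    original = cls.__dict__.get(name, _MISSING)
    setattr(cls, name, replacement)
    try:
        yield
    finally:
        if original is _MISSING:
            delattr(cls, name)  # fall back to the inherited attribute again
        else:
            setattr(cls, name, original)


class Base:
    def __getitem__(self, key):
        return ("base", key)


class Child(Base):  # inherits __getitem__, like torch.Tensor does
    pass


calls = []


def tracing_getitem(self, key):
    calls.append(key)  # record the index for tracing
    return Base.__getitem__(self, key)  # delegate to the original implementation


c = Child()
with patched(Child, "__getitem__", tracing_getitem):
    traced = c[3]

restored_cleanly = "__getitem__" not in Child.__dict__  # inherited again
print(traced, calls)  # ('base', 3) [3]

# The naive approach from the post: read the method, patch, assign it back.
naive = Child.__getitem__
Child.__getitem__ = tracing_getitem
Child.__getitem__ = naive
naive_leaves_entry = "__getitem__" in Child.__dict__  # now defined on Child itself
print(restored_cleanly, naive_leaves_entry)  # True True
```

In both cases `c[...]` still returns the right value for this plain class; the difference is only in the class's own `__dict__`, which is the state a C extension type like `torch.Tensor` may react to.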

Here's the example code:

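import torch

a = torch.randint(1, 10, (15,))  # 1-D tensor with 15 elements
b = torch.randint(1, 10, (35,))  # 1-D tensor with 35 elements

print(a[b])  # OK
print(b[a])  # OK

getitem = torch.Tensor.__getitem__  # keep a reference to the original method

def new_getitem(self, key):
    print(key)  # trace the index
    return getitem(self, key)  # delegate to the original method

torch.Tensor.__getitem__ = new_getitem  # patch
torch.Tensor.__getitem__ = getitem  # restore

print(a[b])  # OK
print(b[a])  # IndexError: too many indices for tensor of dimension 1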

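It seems the behavior of __getitem__ changed after replacing and restoring it: indexing now only works when the index tensor is the larger one (b, 35 elements), while indexing with a (15 elements) fails. Why does this happen, and how do I fix it?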

Thanks!