I'm creating a subclass of `torch.Tensor` and, as described in the docs, overriding all torch functions with `__torch_function__` (see below), but I still get the error above. `torch.index_select` does work when I call it directly on an instance, e.g. `torch.index_select(x, 0, torch.tensor([0,1]))` where `x` is a `Vertex`:
```python
class Vertex(torch.Tensor):
    def __init__(self, data):
        '''
        Theta shall be in radians. List is not supported yet.
        '''
        assert data.ndim in [1,2]
        assert len(data) == 4 if data.ndim==1 else len(data[0])==4, "Please follow g2o format"
        self.id = self[0]
        self.x = self[1]
        self.y = self[2]
        self.theta = self[3]
        super().__init__()

    @classmethod
    def __torch_function__(cls, func, types, args=..., kwargs=None):
        return super().__torch_function__(func, types, args, kwargs)
```
Is there any solution to this error?
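A minimal sketch of a different construction pattern follows. It is not a confirmed fix for the specific error, which isn't shown here. It builds the tensor in `__new__` with `torch.Tensor.as_subclass` and exposes `id`, `x`, `y` and `theta` as properties instead of attributes assigned in `__init__`. As far as I know, the default `Tensor.__torch_function__` converts results back to the subclass with `as_subclass`, which does not run `__init__`, so attributes set there are missing on anything returned by a torch op (including `torch.index_select`). Properties are recomputed on any `Vertex`. The properties index the last dimension (`[..., k]`), because in the code above `self[1]` on a 2-D input selects the second row, not the x column. Storing everything as float32 and requiring a trailing dimension of 4 are assumptions of this sketch.

```python
import torch


class Vertex(torch.Tensor):
    """Sketch: g2o-style vertices, shape (4,) or (N, 4), columns (id, x, y, theta)."""

    @staticmethod
    def __new__(cls, data):
        # Assumption: everything is stored as float32 (so ids become floats).
        data = torch.as_tensor(data, dtype=torch.float32)
        assert data.ndim in (1, 2), "expected shape (4,) or (N, 4)"
        assert data.shape[-1] == 4, "Please follow g2o format"
        # as_subclass returns a Vertex that shares memory with `data`.
        return data.as_subclass(cls)

    def __init__(self, data):
        # Construction happens in __new__; nothing is stored here, because
        # results of torch ops are turned into Vertex without calling __init__.
        pass

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        # Base-class behaviour: run func, then convert Tensor results to cls.
        return super().__torch_function__(func, types, args, kwargs)

    # Fields are read on demand from the last dimension, so they work for a
    # single vertex (4,), a batch (N, 4), and the results of torch ops.
    @property
    def id(self):
        return self[..., 0]

    @property
    def x(self):
        return self[..., 1]

    @property
    def y(self):
        return self[..., 2]

    @property
    def theta(self):
        return self[..., 3]


if __name__ == "__main__":
    v = Vertex([[0.0, 1.0, 2.0, 0.5], [1.0, 3.0, 4.0, 1.0]])
    picked = torch.index_select(v, 0, torch.tensor([1]))
    print(type(picked).__name__, picked.theta)
```

With this layout `picked` is still a `Vertex`, and `picked.theta` works even though `__init__` never ran for it. The trade-off is that the fields are views computed on each access rather than cached attributes.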