I am trying to subclass the `torch.Tensor` class but am running into some difficulties. The above code works well until you call the `to` or `clone` methods: both return an instance of `Tensor`, not `MyObject`.
To improve on the above code, the best I have managed is this:
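```python
import torch

class MyObject(torch.Tensor):
    def __new__(cls, x, extra_data, *args, **kwargs):
        return super().__new__(cls, x, *args, **kwargs)
    def __init__(self, x, extra_data):
        self.extra_data = extra_data
    def clone(self, *args, **kwargs):
        # Re-wrap the plain Tensor returned by Tensor.clone
        return MyObject(super().clone(*args, **kwargs), self.extra_data)
    def to(self, *args, **kwargs):
        # Empty placeholder object, then swap in the converted data
        new_obj = MyObject([], self.extra_data)
        new_obj.data = super().to(*args, **kwargs)
        return new_obj
```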
This works, except when `requires_grad=True`: in that case my `to` method detaches the object from the autograd graph.
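The detaching comes from the `.data` assignment in `to`: autograd does not record writes to `.data`, so the new object keeps the `requires_grad=False` it was created with. A minimal illustration with plain tensors (no subclass), converting the dtype instead of moving to CUDA so it runs on any machine:

```python
import torch

src = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)

# Converting through Tensor.to: autograd records the op, so grad_fn is set
via_to = src.to(torch.float64)

# Converting, then assigning through .data: autograd never sees the swap,
# so the holder keeps its own requires_grad=False and has no grad_fn
via_data = torch.empty(0, dtype=torch.float64)
via_data.data = src.to(torch.float64)

print(via_to.requires_grad, via_to.grad_fn is not None)
print(via_data.requires_grad, via_data.grad_fn is None)
```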
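For example, with `requires_grad=True`:

```python
obj1 = MyObject([1, 2, 3], 'extra_data_123')
obj2 = obj1.to('cuda')
```

Here `obj2.requires_grad` is `False`, but with a plain Tensor: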
```python
t1 = torch.Tensor([1, 2, 3])
t2 = t1.to('cuda')
```
the result stays attached to the graph. How can I improve my `to` method so that it behaves like the one on the `Tensor` class?
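A minimal sketch of a `to` (and `clone`) that stays in the graph, assuming PyTorch 1.7 or later (for `Tensor.as_subclass`): instead of assigning to `.data`, it calls `Tensor.to` directly and re-wraps the result with `as_subclass`, whose documentation says the output stays attached to the autograd graph. The constructor uses `torch.as_tensor` rather than the legacy `super().__new__` call, and `_rewrap` is a hypothetical helper; both are assumptions of this sketch. It converts the dtype instead of calling `.to('cuda')` so it runs without a GPU.

```python
import torch


class MyObject(torch.Tensor):
    # Sketch assuming PyTorch >= 1.7 (Tensor.as_subclass available).

    def __new__(cls, x, extra_data):
        # as_tensor avoids a copy when x is already a float32 Tensor;
        # as_subclass re-wraps it as cls and keeps any autograd history.
        return torch.as_tensor(x, dtype=torch.float32).as_subclass(cls)

    def __init__(self, x, extra_data):
        self.extra_data = extra_data

    def _rewrap(self, result):
        # Hypothetical helper: view the result as MyObject, copy extra_data
        out = result.as_subclass(MyObject)
        out.extra_data = self.extra_data
        return out

    def to(self, *args, **kwargs):
        # Tensor.to is recorded by autograd, unlike assigning to .data
        return self._rewrap(super().to(*args, **kwargs))

    def clone(self, *args, **kwargs):
        return self._rewrap(super().clone(*args, **kwargs))


# Usage: wrap a leaf that requires grad, then convert it
leaf = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
obj1 = MyObject(leaf, "extra_data_123")
obj2 = obj1.to(torch.float64)  # obj1.to("cuda") on a GPU machine
print(type(obj2).__name__, obj2.extra_data, obj2.requires_grad)
```

Because the conversion goes through `Tensor.to`, gradients flow from `obj2` back to `leaf`, and `extra_data` is copied explicitly since `as_subclass` does not call `__init__`.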
By the way, you may want to look into the `torch.tensor` function and use it in the `__new__` method as well. As it stands, you cannot pass in things like a `torch.Tensor` or an instance of this subclass.
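A sketch of that suggestion under stated assumptions: it swaps in `torch.as_tensor` for `torch.tensor`, because `torch.tensor` always copies and warns when handed an existing tensor, while `as_tensor` reuses the data when it can. The `dtype`/`device` parameters and the `as_subclass` wrapping (PyTorch 1.7+) are choices of this sketch, not part of the original code.

```python
import torch


class MyObject(torch.Tensor):
    # Hypothetical variant: build the data with torch.as_tensor so that
    # lists, plain Tensors and existing MyObject instances are all accepted.
    def __new__(cls, x, extra_data, dtype=None, device=None):
        data = torch.as_tensor(x, dtype=dtype, device=device)
        return data.as_subclass(cls)

    def __init__(self, x, extra_data, dtype=None, device=None):
        self.extra_data = extra_data


a = MyObject([1, 2, 3], "a")          # from a Python list
b = MyObject(torch.arange(3), "b")    # from a plain Tensor
c = MyObject(a, "c")                  # from an existing MyObject
```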
I would also like to find where the `torch.tensor` function is defined. Can someone point me to it?