I am trying to subclass the torch.Tensor class but am having some difficulties. The above code works great until you call the to or clone methods. Both return an instance of Tensor, not MyObject.
To improve the above code, the best I can accomplish is this:
class MyObject(torch.Tensor):
    @staticmethod
    def __new__(cls, x, extra_data, *args, **kwargs):
        return super().__new__(cls, x, *args, **kwargs)

    def __init__(self, x, extra_data):
        self.extra_data = extra_data

    def clone(self, *args, **kwargs):
        return MyObject(super().clone(*args, **kwargs), self.extra_data)

    def to(self, *args, **kwargs):
        new_obj = MyObject([], self.extra_data)
        new_obj.data = super().to(*args, **kwargs)
        return new_obj
This works except when requires_grad=True. In that case the to method detaches the object from the graph:
obj1 = MyObject([1, 2, 3], 'extra_data_123')
obj1.requires_grad_(True)
obj2 = obj1.to('cuda')
obj2.requires_grad
The last line returns False, but this:
t1 = torch.Tensor([1, 2, 3])
t1.requires_grad_(True)
t2 = t1.to('cuda')
t2.requires_grad
returns True.
How can I improve the to method to make it work like the Tensor class?
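One possible direction, as a sketch rather than a confirmed fix: the detachment seems to come from assigning the result to .data, which bypasses autograd. The version below keeps the tensor returned by Tensor.to and re-wraps it with Tensor.as_subclass, which the PyTorch docs describe as staying attached to the autograd graph. It assumes a recent PyTorch release, builds the object with torch.tensor instead of the original super().__new__ call, and uses a dtype change in place of 'cuda' so it runs without a GPU.

```python
import torch

class MyObject(torch.Tensor):
    @staticmethod
    def __new__(cls, x, extra_data):
        # Build a plain tensor, then view it as the subclass.
        return torch.tensor(x).as_subclass(cls)

    def __init__(self, x, extra_data):
        self.extra_data = extra_data

    def to(self, *args, **kwargs):
        # Keep the result of Tensor.to itself instead of assigning it to
        # .data, so the autograd history is preserved.
        out = super().to(*args, **kwargs).as_subclass(MyObject)
        out.extra_data = self.extra_data
        return out

obj1 = MyObject([1.0, 2.0, 3.0], 'extra_data_123')
obj1.requires_grad_(True)
obj2 = obj1.to(torch.float64)  # dtype change stands in for .to('cuda')
```

With this sketch, obj2 should be a MyObject that still requires grad and carries extra_data; the same pattern could be applied to clone.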
By the way, you may want to look into the torch.tensor function and put that into the __new__ method as well. As is, you cannot pass things like dtype to torch.Tensor or this subclass.
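A minimal sketch of that suggestion, assuming a recent PyTorch release where Tensor.as_subclass is available: __new__ forwards keyword arguments such as dtype to the torch.tensor factory function and then views the result as the subclass. Forwarding **kwargs this way is an illustrative choice, not something taken from the original code.

```python
import torch

class MyObject(torch.Tensor):
    @staticmethod
    def __new__(cls, x, extra_data, **kwargs):
        # torch.tensor accepts keyword arguments such as dtype and device.
        return torch.tensor(x, **kwargs).as_subclass(cls)

    def __init__(self, x, extra_data, **kwargs):
        # __init__ receives the same arguments as __new__, so it has to
        # accept the extra keyword arguments as well.
        self.extra_data = extra_data

obj = MyObject([1, 2, 3], 'extra_data_123', dtype=torch.float64)
```

Note that torch.tensor copies its input, and it emits a warning if x is already a tensor, so this form is best suited to lists and other plain data.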
It would be great if I could find where the torch.tensor function is defined. Can someone tell me?