# Output of the first attempt, presumably from an earlier version of MyObject
# that had no class-level default for the extra data (the prose line below
# describes adding one).
# NOTE(review): MyObject is only defined further down in this file, so as
# pasted these lines raise NameError if the file is run top to bottom.
obj1 = MyObject([1], 'extra_data_1')
obj2 = MyObject([2], 'extra_data_2')
print(type(obj1), type(obj2))  # <- <class '__main__.MyObject'> <class '__main__.MyObject'>
# Named `total` rather than `sum` so the builtin sum() is not shadowed.
total = obj1 + obj2
print(total)  # <- tensor([3.])
print(type(total))  # <- <class '__main__.MyObject'>
obj3 = total.clone()  # <- AttributeError: 'MyObject' object has no attribute 'extra_data'
I tried to fix the issue by adding a class variable (a class-level default for `_extra_data`):
import torch
from torch import nn
from torch.autograd import Variable
import torch.nn.functional as F
class MyObject(torch.Tensor):
    """A ``torch.Tensor`` subclass that carries an extra ``extra_data`` payload.

    ``extra_data`` is set by the constructor and copied explicitly only by
    ``clone()`` and ``to()``. Results of other torch operations (e.g.
    ``obj1 + obj2``) are still ``MyObject`` instances but have no payload of
    their own, so ``extra_data`` falls back to the class-level default ``''``.

    NOTE(review): to carry ``extra_data`` through arbitrary torch ops, a
    ``__torch_function__`` override would be needed; that is not done here.
    """

    # Class-level fallback. Instances produced by torch ops do not get a
    # per-instance _extra_data (see the demo: extra_data prints nothing),
    # so without this default reading extra_data on them would fail.
    _extra_data = ''

    def __new__(cls, x, extra_data, *args, **kwargs):
        # __new__ is implicitly a static method. extra_data is swallowed here
        # so it is not forwarded to the torch.Tensor constructor; __init__
        # stores it on the instance.
        return super().__new__(cls, x, *args, **kwargs)

    def __init__(self, x, extra_data):
        self._extra_data = extra_data

    def _with_extra_data(self, result):
        """Return *result* as a MyObject carrying this tensor's extra_data."""
        if not isinstance(result, MyObject):
            # as_subclass shares the data and stays attached to autograd.
            result = result.as_subclass(MyObject)
        result._extra_data = self._extra_data
        return result

    def clone(self, *args, **kwargs):
        """Clone like ``torch.Tensor.clone`` and copy ``extra_data`` over."""
        return self._with_extra_data(super().clone(*args, **kwargs))

    def to(self, *args, **kwargs):
        """Convert like ``torch.Tensor.to`` and keep ``extra_data``.

        Delegates to ``torch.Tensor.to`` so the result stays connected to the
        autograd graph. Like ``torch.Tensor.to``, this may return ``self``
        when no conversion is needed.
        """
        return self._with_extra_data(super().to(*args, **kwargs))

    @property
    def extra_data(self):
        """The attached payload (class default ``''`` if never set)."""
        return self._extra_data

    @extra_data.setter
    def extra_data(self, d):
        self._extra_data = d
# Demo of the version of MyObject with a class-level _extra_data default.
obj1 = MyObject([1], 'extra_data_1')
obj2 = MyObject([2], 'extra_data_2')
print(type(obj1), type(obj2))  # <- <class '__main__.MyObject'> <class '__main__.MyObject'>
# Named `total` rather than `sum` so the builtin sum() is not shadowed.
total = obj1 + obj2
print(total)  # <- tensor([3.])
print(type(total))  # <- <class '__main__.MyObject'>
obj3 = total.clone()  # <- Fine
# The addition result has no extra_data of its own, so the clone inherits
# the class default '' and this prints an empty line.
print(obj3.extra_data)
obj3.extra_data = 'obj3_extra_data'
print(obj3.extra_data)  # prints obj3_extra_data