I gave it a shot, going beyond what was discussed in the older discussion you pasted.
Example:
import torch
from torch import nn
from torch.autograd import Variable
import torch.nn.functional as F
class MyObject(torch.Tensor):
    """A ``torch.Tensor`` subclass that carries a user-supplied ``extra_data``
    payload and keeps it (along with ``requires_grad``) across ``clone`` and
    ``to``, which would otherwise return plain tensors without the payload."""

    @staticmethod
    def __new__(cls, x, extra_data, *args, **kwargs):
        # Only the tensor data goes to the base constructor; the payload is
        # attached in __init__.
        return super().__new__(cls, x, *args, **kwargs)

    def __init__(self, x, extra_data):
        self.extra_data = extra_data

    def clone(self, *args, **kwargs):
        # Re-wrap the cloned tensor so the payload travels with the copy.
        return MyObject(super().clone(*args, **kwargs), self.extra_data)

    def to(self, *args, **kwargs):
        # Convert first, then rebuild a MyObject around the result so both
        # extra_data and requires_grad survive the device/dtype move.
        converted = super().to(*args, **kwargs)
        wrapper = MyObject([], self.extra_data)
        wrapper.data = converted.data
        wrapper.requires_grad = converted.requires_grad
        return wrapper
# Demo: the subclass keeps requires_grad across .to(), matching plain tensors.
obj1 = MyObject([1, 2, 3], 'extra_data_123')
obj1.requires_grad_(True)
print(obj1.requires_grad)

obj2 = obj1.to('cuda')
print(obj2.requires_grad)

# Baseline: the same round-trip with an ordinary torch.Tensor.
t1 = torch.Tensor([1, 2, 3]).requires_grad_(True)
t2 = t1.to('cuda')
print(t2.requires_grad)
Output:
True
True
True
Hope this helps.