Subclassing torch.LongTensor

Related to Subclassing torch.Tensor, but I am trying to extend a tensor with dtype=torch.int64.
My code looks like this:

import torch

class MyTensor(torch.LongTensor):  # the TypeError below is raised here, when the class is created
    def __new__(cls, data, stats, *args, **kwargs):
        return super().__new__(cls, data, *args, **kwargs)

    def __init__(self, data, stats, *args, **kwargs):
        self._stats = stats

x = MyTensor([0, 1, 2], 3)

This works when subclassing torch.Tensor, but it does not work for torch.LongTensor:
TypeError: type 'torch.LongTensor' is not an acceptable base type
How can I subclass torch.LongTensor?

Don’t use torch.LongTensor; it’s not a proper class/type but a hack!
The proper approach is to subclass torch.Tensor and keep the dtype fixed to torch.long.
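A quick check illustrates the point, assuming a PyTorch version where the legacy torch.LongTensor constructor is still available: the object it builds is an ordinary torch.Tensor whose dtype happens to be int64, and the name "torch.LongTensor" survives only as a type string.

```python
import torch

# Legacy constructor: looks like a class, but what it builds is a plain torch.Tensor.
legacy = torch.LongTensor([0, 1, 2])
# Modern equivalent: a regular tensor with the dtype fixed to torch.long (int64).
modern = torch.tensor([0, 1, 2], dtype=torch.long)

print(type(legacy) is torch.Tensor)   # True
print(legacy.dtype, modern.dtype)     # torch.int64 torch.int64
print(legacy.type())                  # torch.LongTensor (only a type *string*)
print(torch.equal(legacy, modern))    # True
```

So "a LongTensor" is really just "a Tensor with dtype=torch.long", which is why the dtype, not the base class, is what a subclass should pin down.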

I have tried to subclass torch.Tensor, but dtype cannot be passed as a parameter, and I could not find out how to initialize the dtype.
Here is what I have tried:

def __new__(self, data, stats):
    tensor = torch.as_tensor(data, dtype=torch.long)
    return torch.Tensor.__new__(self, tensor)   # Error: expected Float (got Long)
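The first attempt fails because torch.Tensor(...) (and torch.Tensor.__new__) is the legacy constructor, which is tied to the default floating-point dtype. A small check, assuming the default dtype has not been changed with torch.set_default_dtype:

```python
import torch

# torch.Tensor(...) builds a tensor of the *default* dtype (float32 unless changed),
# even when the input data are integers.
t = torch.Tensor([1, 2, 3])
print(t.dtype)  # torch.float32

long_data = torch.as_tensor([1, 2, 3], dtype=torch.long)
try:
    # Passing an int64 tensor is not converted; the question reports it being rejected
    # with "expected Float (got Long)". Both exception types are caught to be safe.
    torch.Tensor(long_data)
except (TypeError, RuntimeError) as err:
    print("rejected:", err)
```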

or like this:

def __new__(self, data, stats):
    tensor = torch.as_tensor(data, dtype=torch.long)
    return tensor
def __init__(self, data, stats):
    self.stats = stats

x = MyTensor([1, 2, 3], 3)   # Actually returns a plain torch.LongTensor, not a MyTensor object
x.stats   # AttributeError: no attribute 'stats'
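The second attempt shows standard Python behaviour rather than anything PyTorch-specific: if __new__ returns an object that is not an instance of the class, Python never calls __init__. A minimal pure-Python illustration (the Demo class is hypothetical):

```python
class Demo:
    def __new__(cls, value, stats):
        # __new__ receives the class itself as its first argument.
        print("__new__ got", cls.__name__)
        # Returning something that is NOT a Demo instance...
        return [value]

    def __init__(self, value, stats):
        # ...means Python skips __init__, so this line never runs.
        self.stats = stats


x = Demo(1, stats=3)
print(type(x).__name__)     # list
print(hasattr(x, "stats"))  # False
```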

You’re diving right into the tricky bits, but:

  • __new__ is implicitly a static method that receives the class (not an instance) as its first argument, so call that argument cls (don’t ask).
  • You can use nn.Parameter as an example. Adapted, this gives:
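import torch

class MyTensor(torch.Tensor):
    def __new__(cls, data, stats, requires_grad=False):
        # Convert the input to int64 up front.
        data = torch.as_tensor(data, dtype=torch.long)
        # Same trick nn.Parameter uses: re-wrap `data` as an instance of cls.
        # Integer tensors cannot require gradients, so keep requires_grad=False.
        tensor = torch.Tensor._make_subclass(cls, data, requires_grad)
        tensor.stats = stats
        return tensor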

So I’m still not 100% sure what you’re trying to achieve, but if it is a subclass that is similar to nn.Parameter, this would be what I’d use.
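If relying on the private torch.Tensor._make_subclass is a concern, the same idea can be sketched with the public Tensor.as_subclass method instead (assuming PyTorch 1.7 or later, where as_subclass exists; the stats attribute is just the one from the question):

```python
import torch

class MyTensor(torch.Tensor):
    """Sketch: an int64 tensor that carries an extra `stats` attribute."""

    def __new__(cls, data, stats):
        # Force int64 regardless of the input's dtype.
        base = torch.as_tensor(data, dtype=torch.long)
        # as_subclass returns an object sharing base's data, but whose Python type is cls.
        obj = base.as_subclass(cls)
        obj.stats = stats
        return obj


x = MyTensor([1, 2, 3], stats=3)
print(type(x).__name__, x.dtype, x.stats)  # MyTensor torch.int64 3
```

Note that if data is already an int64 tensor, torch.as_tensor does not copy it, so the resulting MyTensor shares memory with the input (the _make_subclass version behaves the same way).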

Best regards