How to inherit from a sparse tensor

I am trying to inherit from a sparse tensor.

While the following works:

import torch

class TestTensor(torch.Tensor):
    @staticmethod
    def __new__(cls, data):
        tensor = torch.tensor([data])

        return tensor.as_subclass(cls)

using the sparse_coo_tensor function instead results in an error:

import torch

class SparseMasksTensor(torch.Tensor):
    @staticmethod
    def __new__(cls, indices, values, size, requires_grad=False):
        sparse_tensor = torch.sparse_coo_tensor(
            indices, values, size, requires_grad=requires_grad, is_coalesced=True
        )
        return sparse_tensor.as_subclass(cls)

sparse_tensor = SparseMasksTensor([[1,2,3],[0,2,3]], [1,2,3], (4,4))

Error:

NotImplementedError: Cannot access storage of SparseTensorImpl

How can I create inheritance correctly or should I stick to composition in this case?

Yeah, I'd say that composition is generally preferred when working with PyTorch tensor subclasses.

That's a pity, since all I want is to extend the Tensor with a few extra methods…

Sorry, my original reply might have been misleading.

To be clear, you would still inherit from torch.Tensor, but you would (1) hold the sparse tensor as a field (an attribute of the subclass instance) and (2) in your __torch_dispatch__, forward every operator call to that inner sparse tensor and return the result.
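A minimal sketch of that idea, loosely modeled on the wrapper-subclass pattern used in subclass_zoo (not copied from it). It relies on the private helpers torch.Tensor._make_wrapper_subclass and torch.utils._pytree.tree_map, which may change between releases. The important part is that the outer tensor is created from the sparse tensor's size, dtype and device only, so nothing tries to read the sparse tensor's storage. Unlike the constructor in the question, this __new__ takes an already-built sparse tensor, and num_specified is a hypothetical example of an extra method. Since no layout is passed, the outer tensor presumably reports the default strided layout; the real sparse layout lives on self.elem.

```python
import torch
from torch.utils._pytree import tree_map


class SparseMasksTensor(torch.Tensor):
    """Wrapper subclass that holds a sparse COO tensor in ``self.elem``."""

    # Skip __torch_function__ so every op goes straight to __torch_dispatch__.
    __torch_function__ = torch._C._disabled_torch_function_impl

    @staticmethod
    def __new__(cls, elem):
        # The outer tensor only carries metadata (size, dtype, device);
        # it never touches elem's storage, which sparse tensors do not expose.
        r = torch.Tensor._make_wrapper_subclass(
            cls,
            elem.size(),
            dtype=elem.dtype,
            device=elem.device,
            requires_grad=elem.requires_grad,
        )
        r.elem = elem
        return r

    def __repr__(self):
        # Avoid the default Tensor repr, which would run ops on the wrapper.
        return f"SparseMasksTensor({self.elem!r})"

    # Hypothetical extra method: number of explicitly stored entries.
    def num_specified(self):
        return self.elem._nnz()

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        def unwrap(x):
            return x.elem if isinstance(x, cls) else x

        def wrap(x):
            return cls(x) if isinstance(x, torch.Tensor) else x

        # Run the aten op on the inner (plain sparse) tensors, then re-wrap.
        out = func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs or {}))
        return tree_map(wrap, out)


indices = torch.tensor([[1, 2, 3], [0, 2, 3]])
values = torch.tensor([1.0, 2.0, 3.0])
a = SparseMasksTensor(torch.sparse_coo_tensor(indices, values, (4, 4)))
b = a + a  # dispatched as aten.add.Tensor on the two inner sparse tensors
print(type(b).__name__, b.elem.layout)
```

Every op that the sparse backend supports should work this way; ops it does not support will fail inside __torch_dispatch__ with the sparse kernel's own error.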

Some of the examples in the GitHub repository albanD/subclass_zoo might be helpful.

For future readers: the behavior described here is implemented in that repository by the class TrivialTensorViaComposition.

EDIT: Following that example, I run into the same issue as above:

return super().__new__(cls, elem)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^
NotImplementedError: Cannot access storage of SparseTensorImpl

In this case, elem is the sparse tensor, and the error is raised in the BaseTensor class, which I left unchanged.

While the above did not work, the approach from the documentation did, sort of:

https://pytorch.org/docs/stable/notes/extending.html#extending-torch-with-a-tensor-wrapper-type

import torch

class MaskSparseTensor(torch.Tensor):
    def __init__(self, data, **kwargs):
        tensor = torch.as_tensor(data, **kwargs)
        self.sparse = tensor.to_sparse_coo()

    def __repr__(self):
        return f"data:\n{self.sparse}"

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        args = [getattr(a, "sparse", a) for a in args]
        ret = func(*args, **kwargs)
        return MaskSparseTensor(ret)

At least I am now able to create instances.

As soon as I try:

data1 = torch.randn(1000, 3, 4)
data2 = torch.randn(1000, 3, 4)

inst1 = MaskSparseTensor(data1)
inst2 = MaskSparseTensor(data2)

print(inst1 + inst2)

I get the well-beloved NotImplementedError: Cannot access storage of SparseTensorImpl again.
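One likely culprit: the wrapper example in the linked documentation (MetadataTensor) inherits from object, not from torch.Tensor. With torch.Tensor as the base class, the final MaskSparseTensor(ret) in __torch_function__ presumably sends the sparse result through torch.Tensor's own constructor, which is where storage would be accessed. A minimal sketch that stays closer to the documentation: a plain duck-typed class, an __init__ that keeps already-sparse inputs as they are, and an explicit __add__, because Python operators are not routed through __torch_function__ for classes that do not subclass torch.Tensor. The __add__ forwarding is my addition, not something the docs example defines.

```python
import torch


class MaskSparseTensor:
    """Duck-typed wrapper (does NOT subclass torch.Tensor) around a sparse COO tensor."""

    def __init__(self, data, **kwargs):
        if isinstance(data, torch.Tensor) and data.is_sparse:
            self.sparse = data  # e.g. the result of a sparse op: keep as-is
        else:
            self.sparse = torch.as_tensor(data, **kwargs).to_sparse_coo()

    def __repr__(self):
        return f"data:\n{self.sparse}"

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        args = [getattr(a, "sparse", a) for a in args]
        ret = func(*args, **kwargs)
        # Re-wrap tensor results only; pass anything else through unchanged.
        return cls(ret) if isinstance(ret, torch.Tensor) else ret

    def __add__(self, other):
        # torch.add sees a __torch_function__ argument and calls the hook above.
        return torch.add(self, other)


inst1 = MaskSparseTensor(torch.randn(10, 3, 4))
inst2 = MaskSparseTensor(torch.randn(10, 3, 4))
total = inst1 + inst2
print(type(total).__name__, total.sparse.layout)
```

Other torch functions such as torch.mul(inst1, inst2) go through __torch_function__ as well; the matching operators (*, -, and so on) would each need a forwarding method like __add__. The trade-off is that instances are no longer isinstance(x, torch.Tensor).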