How can I tell the difference between a regular tensor and a nested tensor?

With the recently introduced nested tensors, suppose I create a nested tensor:

import torch

a = torch.randn(20, 128)
nt = torch.nested.nested_tensor([a, a], dtype=torch.float32)

and check its type:

type(nt)
torch.Tensor

it appears to be a regular Tensor object. If I want to distinguish between a nested tensor and a regular tensor in my code, how can I do it, given that both type(nt) == torch.Tensor and isinstance(nt, torch.Tensor) return True?

One approach I thought of relies on the fact that (currently) the size method behaves differently: for a nested tensor it requires a dimension argument, otherwise it raises a RuntimeError. So I could do:

import torch

def is_nested_tensor(nt):
    if not isinstance(nt, torch.Tensor):
        return False

    try:
        # size() without a dimension argument raises a RuntimeError
        # for nested tensors but succeeds for regular tensors
        nt.size()
        return False
    except RuntimeError:
        return True

But is there a simpler way?

You could check the is_nested attribute:

import torch

a = torch.randn(20, 128)
nt = torch.nested.nested_tensor([a, a], dtype=torch.float32)

print(a.is_nested)
# False

print(nt.is_nested)
# True
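Using is_nested, the helper from the question reduces to a one-line check. Below is a minimal sketch; the names is_nested_tensor and component_shapes are hypothetical, and the second function only illustrates branching on the attribute (it assumes unbind() splits a nested tensor into its constituent tensors along dim 0).

```python
import torch


def is_nested_tensor(obj) -> bool:
    # Non-tensors have no is_nested attribute, so check the type first.
    # For any torch.Tensor, is_nested is True only for nested tensors.
    return isinstance(obj, torch.Tensor) and obj.is_nested


def component_shapes(t: torch.Tensor) -> list:
    # Hypothetical helper: a nested tensor is split into its constituent
    # tensors with unbind(); a regular tensor is treated as one component.
    if is_nested_tensor(t):
        return [tuple(c.shape) for c in t.unbind()]
    return [tuple(t.shape)]


a = torch.randn(20, 128)
b = torch.randn(10, 128)
nt = torch.nested.nested_tensor([a, b], dtype=torch.float32)

print(is_nested_tensor(a))     # False
print(is_nested_tensor(nt))    # True
print(is_nested_tensor([a]))   # False
print(component_shapes(nt))    # [(20, 128), (10, 128)]
print(component_shapes(a))     # [(20, 128)]
```

Unlike the size()-based check, this does not rely on an exception being raised, so it keeps working even if the behaviour of size() on nested tensors changes.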

Great, thanks. That’s just what I was after!