PyTorch 1.7.0: AttributeError: 'torch.Size' object has no attribute 'dtype' when using `__torch_function__`

The following code worked before the 1.7.0 update, and now it fails.

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        args = [a.tensor() if hasattr(a, 'tensor') else a for a in args]
        test = func(*args, **kwargs)
        print(test.dtype)
        return TestTensor(test.float())  # Also fails (on .float()) if the print above is commented out.

This results in the following error:

         25         test = func(*args, **kwargs)
    ---> 26         print(test.dtype)
         27         return TestTensor(test.float())
         28
    AttributeError: 'torch.Size' object has no attribute 'dtype'

How do I fix this?

It turns out this is caused by a change in 1.7.0 that is listed in the docs (Extending PyTorch, PyTorch 2.1 documentation) but not in the release notes:

> One should be careful within `__torch_function__` for subclasses to always call `super().__torch_function__(func, ...)` instead of `func` directly, as was the case before version 1.7.0. Failing to do this may cause `func` to recurse back into `__torch_function__` and therefore cause infinite recursion.
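Applying that advice to the code above: since 1.7.0, `__torch_function__` is also invoked for `Tensor` methods and attribute access such as `.size()` and `.shape`, so `func` can return size metadata, a dtype or a Python list rather than a tensor, which is where the `torch.Size` in the traceback comes from. The sketch below is one possible fix, not the only one. It assumes `TestTensor` is a `torch.Tensor` subclass (the question does not show its definition), so the `a.tensor()` unwrapping is dropped and tensors are passed to `func` as-is. It delegates to `super().__torch_function__` and only post-processes results that really are tensors, keeping the question's conversion to float.

    import torch


    class TestTensor(torch.Tensor):
        """Hypothetical torch.Tensor subclass standing in for the question's TestTensor."""

        @classmethod
        def __torch_function__(cls, func, types, args=(), kwargs=None):
            if kwargs is None:
                kwargs = {}
            # Delegate to torch.Tensor.__torch_function__ instead of calling func
            # directly: it runs func with torch-function handling disabled (so
            # func cannot recurse back into this method) and re-wraps Tensor
            # outputs as cls.
            result = super().__torch_function__(func, types, args, kwargs)
            # This hook also fires for Tensor methods and attributes such as
            # .size(), .shape and .dtype, so result is not always a tensor.
            # Only post-process real tensors.
            if isinstance(result, torch.Tensor) and result.dtype != torch.float32:
                # .dtype and .float() re-enter this hook; the nested calls stop
                # once the tensor is float32, so there is no infinite recursion.
                result = result.float()
            return result


    x = torch.tensor([1, 2, 3]).as_subclass(TestTensor)
    print(x.size())  # size metadata, not a tensor, and no AttributeError
    y = x + 1
    print(type(y).__name__, y.dtype)

The `isinstance` guard is what addresses the reported error; the dtype check keeps the float conversion without looping. Whether every tensor result should be cast to float at all is a design decision for the real class.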