Understanding __torch_function__

Hi folks,

I am subclassing torch.Tensor and am struggling to understand __torch_function__, specifically its types argument.

  • What are the expected inputs?
  • And what is the intended use?

I did read through this page, where it says:

The __torch_function__ method takes four arguments: func, a reference to the torch API function that is being overridden, types, the list of types of Tensor-likes that implement __torch_function__, args, the tuple of arguments passed to the function, and kwargs, the dict of keyword arguments passed to the function.

Unfortunately, I can’t gain much from that explanation. Where do these types come from, and what are they the types of?

From playing around a little, I observed that types always seems to be a 1-tuple containing the type of my subclass, never anything else.
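A small experiment along these lines can be sketched as follows. The class names Base and Child and the seen list are hypothetical, made up for illustration. As far as I understand the dispatch, types holds the unique types of the arguments that implement __torch_function__, so a 1-tuple is what you get when all tensor arguments share one subclass, while mixing two subclasses in one call should put both into types:

```python
import torch

seen = []  # every `types` tuple passed to __torch_function__

class Base(torch.Tensor):
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        seen.append(types)  # record what the dispatcher handed us
        return super().__torch_function__(func, types, args, kwargs)

class Child(Base):  # hypothetical second subclass, inherits the hook
    pass

a = torch.ones(2).as_subclass(Base)
b = torch.ones(2).as_subclass(Child)

# Both arguments have the same type: only one entry in `types`.
seen.clear()
torch.add(a, a)
same_types = set(seen[0])

# Arguments of two different subclasses: both types are reported.
seen.clear()
result = torch.add(a, b)
mixed_types = set(seen[0])

print(same_types)
print(mixed_types)
```

I compare sets here rather than tuples because I did not want to rely on the order in which the types are listed.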

Motivation for this question:
I have encountered a bug in my project, and the following seems to fix it (in my subclass):

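@classmethod  # __torch_function__ is expected to be a classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
    # Pass (torch.Tensor,) to the parent instead of the real `types`.
    return super().__torch_function__(func, (torch.Tensor,), args, kwargs)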

But since I do not properly understand __torch_function__ and its types argument, I am not confident about how good an idea this is.
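To make the effect of that override concrete, here is a minimal sketch of one situation where it changes behaviour. It is not necessarily the bug from my project. It assumes that the default torch.Tensor.__torch_function__ returns NotImplemented unless cls is a subclass of every entry in types, which is my reading of the current implementation. The class names Plain, Other and Patched are hypothetical:

```python
import torch

class Plain(torch.Tensor):  # hypothetical subclass using the default hook
    pass

class Other(torch.Tensor):  # a second, unrelated subclass
    pass

class Patched(torch.Tensor):  # subclass with the override from above
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        # Replace the real `types` so the parent's subclass check passes.
        return super().__torch_function__(func, (torch.Tensor,), args, kwargs)

plain = torch.ones(2).as_subclass(Plain)
other = torch.ones(2).as_subclass(Other)
patched = torch.ones(2).as_subclass(Patched)

# Two unrelated subclasses with the default hook: neither accepts the
# other's type, so the call is expected to fail with a TypeError.
try:
    torch.add(plain, other)
    default_raised = False
except TypeError:
    default_raised = True

# With the override, the subclass check is bypassed and the call goes through.
patched_result = torch.add(patched, other)

print(default_raised, type(patched_result).__name__)
```

The trade-off, as far as I can tell, is that the override silently accepts any mix of subclasses, so an unrelated subclass never gets a chance to handle the call once Patched has been asked.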

Any help appreciated, thanks in advance!


I am also interested in understanding what the types argument represents. Thanks in advance :grinning:

In case someone else comes across this: I took some time to write down what I have learned since asking this question here.