I am subclassing torch.Tensor, and am struggling to understand __torch_function__, and specifically its types argument:
- What are the expected inputs?
- And what is the intended use?
I did read through this page, where it says:
The __torch_function__ method takes four arguments:
func, a reference to the torch API function that is being overridden,
types, the list of types of Tensor-likes that implement __torch_function__,
args, the tuple of arguments passed to the function, and
kwargs, the dict of keyword arguments passed to the function.
Unfortunately, I can’t gain much from that explanation. Where do these types come from, and what are they the types of?
From playing around a little, I observed that types always seems to be a 1-tuple containing the type of my subclass, never anything else.
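A minimal sketch of that kind of experiment, for anyone who wants to reproduce it. The names `Probe`, `ChildProbe` and `calls` are hypothetical and only for illustration; the override just records what it is handed and then defers to the default implementation.

```python
import torch

calls = []  # one (function name, types, args, kwargs) entry per intercepted call


class Probe(torch.Tensor):
    """Hypothetical subclass that only records what __torch_function__ receives."""

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        name = getattr(func, "__name__", repr(func))
        calls.append((name, tuple(types), args, kwargs))
        # Defer to the default behaviour, passing `types` through unchanged.
        return super().__torch_function__(func, types, args, kwargs)


class ChildProbe(Probe):
    """Hypothetical subclass of Probe; inherits the recording override."""


plain = torch.ones(2)
p = torch.ones(2).as_subclass(Probe)
c = torch.ones(2).as_subclass(ChildProbe)

calls.clear()  # ignore anything recorded while building the tensors

torch.add(p, plain)  # subclass mixed with a plain tensor
torch.add(p, p)      # subclass mixed with itself
torch.add(p, c)      # subclass mixed with its own subclass
torch.sum(p, dim=0)  # a call that uses a keyword argument

for name, types, args, kwargs in calls:
    print(name, [t.__name__ for t in types], kwargs)
```

If `types` is collected the way the quoted docs describe, the first two `add` calls should report only `Probe`, which matches the 1-tuple observation. The third should report both `Probe` and `ChildProbe`, so more than one entry seems to appear only once different subclasses meet in the same call.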
Motivation for this question:
I have encountered a bug in my project, and the following seems to fix it (in my subclass):
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        return super().__torch_function__(func, (torch.Tensor,), args, kwargs)
But not understanding __torch_function__ and its types argument properly, I am not confident about how good of an idea this is.
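One way to probe what the override changes is the sketch below. It is not a recommendation. It assumes the default torch.Tensor.__torch_function__ declines (returns NotImplemented) unless cls is a subclass of every entry in types; that is an assumption about the default implementation, not something the quoted docs state. The classes `A`, `B` and `Relaxed` are hypothetical.

```python
import torch


class A(torch.Tensor):
    """Hypothetical subclass relying on the default __torch_function__."""


class B(torch.Tensor):
    """Second, unrelated hypothetical subclass, also using the default."""


class Relaxed(torch.Tensor):
    """Hypothetical subclass using the override from the question."""

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        # Replace the real `types` so the default implementation only ever
        # sees torch.Tensor, which every Tensor subclass derives from.
        return super().__torch_function__(func, (torch.Tensor,), args, kwargs)


a = torch.ones(2).as_subclass(A)
b = torch.ones(2).as_subclass(B)
r = torch.ones(2).as_subclass(Relaxed)

# Two unrelated subclasses, both with the default implementation.
try:
    torch.add(a, b)
    mixed_default = "ok"
except TypeError:
    mixed_default = "TypeError"

# Same kind of mix, but one operand uses the override from the question.
out = torch.add(r, b)

print(mixed_default, type(out).__name__)
```

Under that assumption, the first call is rejected because neither `A` nor `B` agrees to handle the other, while the second goes through because `Relaxed` no longer sees `B` in `types`. So the override would trade the compatibility check for always handling the call, which may or may not be what is wanted.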
Any help appreciated, thanks in advance!