Hi folks,
I am subclassing `torch.Tensor`, and am struggling to understand `__torch_function__`, and specifically its `types` argument.
- What are the expected inputs?
- And what is the intended use?
I did read through this page, where it says:
> The `__torch_function__` method takes four arguments: `func`, a reference to the torch API function that is being overridden, `types`, the list of types of Tensor-likes that implement `__torch_function__`, `args`, the tuple of arguments passed to the function, and `kwargs`, the dict of keyword arguments passed to the function.
Unfortunately, I can’t get much out of that explanation. Where do these types come from, and what are they the types of?
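One way to see where `types` comes from is to record it. This is a minimal sketch, assuming a hypothetical subclass `MyTensor` whose override only logs `types` and then defers to the default implementation unchanged. As far as I can tell, plain `torch.Tensor` arguments are left out of `types`; only the types of arguments that bring their own `__torch_function__` are listed.

```python
import torch

class MyTensor(torch.Tensor):  # hypothetical subclass name
    seen_types = []  # records the `types` tuple of every intercepted call

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        cls.seen_types.append(tuple(types))
        # Defer to the default implementation with the unmodified `types`
        return super().__torch_function__(func, types, args, kwargs)

plain = torch.ones(3)                       # an ordinary torch.Tensor
mine = torch.ones(3).as_subclass(MyTensor)  # an instance of the subclass

MyTensor.seen_types.clear()
out = torch.add(plain, mine)

print(MyTensor.seen_types)  # one entry: a 1-tuple holding MyTensor
print(type(out).__name__)   # the default implementation wraps the result as MyTensor
```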
From playing around a little, I observed that `types` always seems to be a 1-tuple containing the type of my subclass, never anything else.
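`types` should only hold more than one entry when a single call receives several distinct Tensor-like types. A sketch with two hypothetical, unrelated subclasses `A` and `B`, both deferring to the default implementation with the unmodified `types`. In the PyTorch versions I have looked at, this combination ends in a `TypeError`, because each default call declines with `NotImplemented`.

```python
import torch

recorded = []  # (class whose override ran, the `types` tuple it received)

class A(torch.Tensor):  # hypothetical subclass
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        recorded.append((cls, tuple(types)))
        return super().__torch_function__(func, types, args, kwargs)

class B(torch.Tensor):  # a second, unrelated hypothetical subclass
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        recorded.append((cls, tuple(types)))
        return super().__torch_function__(func, types, args, kwargs)

a = torch.ones(2).as_subclass(A)
b = torch.ones(2).as_subclass(B)

try:
    torch.add(a, b)
    failed = False
except TypeError as err:
    # Both overrides deferred to the default, which returned NotImplemented
    # each time, so PyTorch found no implementation and raised TypeError.
    failed = True
    print(err)

for cls, types in recorded:
    print(cls.__name__, [t.__name__ for t in types])  # both see A and B
```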
Motivation for this question:
I have encountered a bug in my project, and the following seems to fix it (in my subclass):
```python
import torch

class MyTensor(torch.Tensor):  # placeholder name for my subclass
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        return super().__torch_function__(func, (torch.Tensor,), args, kwargs)
```
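A sketch of what replacing `types` changes, assuming the inherited `torch.Tensor.__torch_function__` still returns `NotImplemented` unless `cls` is a subclass of every entry in `types` (the check I see in `torch/_tensor.py`). `A` and `B` are hypothetical subclasses that keep the default, and the default is called directly here rather than through a real dispatch.

```python
import torch

class A(torch.Tensor):  # hypothetical subclasses that keep the default
    pass                # torch.Tensor.__torch_function__

class B(torch.Tensor):
    pass

a = torch.ones(2).as_subclass(A)
b = torch.ones(2).as_subclass(B)

# Call A's inherited default __torch_function__ directly: once with the
# `types` a call mixing A and B would carry, once with the replacement
# `(torch.Tensor,)` used in the workaround.
declined = A.__torch_function__(torch.add, (A, B), (a, b), {})
accepted = A.__torch_function__(torch.add, (torch.Tensor,), (a, b), {})

print(declined is NotImplemented)  # A is not a subclass of B, so it declines
print(type(accepted).__name__)     # the check passes; result is wrapped as A
```

Under that assumption, passing `(torch.Tensor,)` makes the check always pass, so the subclass accepts calls that mix in unrelated Tensor-likes and wraps the result as its own type. Whether that is desirable depends on whether those other types should have a say in the result.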
But since I don't properly understand `__torch_function__` and its `types` argument, I am not confident that this is a good idea.
Any help appreciated, thanks in advance!