I’m trying to make a tensor subclass that has an extra instance attribute to store data, and I need to store multiple of these subclass objects in a tensor. When I use torch.stack to store them, the elements of the stacked tensor no longer have the extra attribute that they were assigned.
Do you know how we can store multiple of these subclass objects in a tensor (or in a subclass of a tensor) while preserving the extra attributes such that I can access them as stacked[0].extra?
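The original class definition was not posted. The sketch below is a hypothetical reconstruction that assumes tensorExtra's __new__ takes a required extra argument and stores it as an instance attribute. It reproduces the symptom on a reasonably recent PyTorch: the stacked result, and each indexed "element" of it, is a tensorExtra, but none of them has an extra attribute.

```python
import torch

class tensorExtra(torch.Tensor):
    # Hypothetical reconstruction of the class from the question:
    # extra is a required argument, so tensorExtra(data) alone raises TypeError.
    def __new__(cls, data, extra):
        obj = torch.Tensor._make_subclass(cls, torch.as_tensor(data))
        obj.extra = extra
        return obj

a = tensorExtra([1.0, 2.0], extra="a")
b = tensorExtra([3.0, 4.0], extra="b")

stacked = torch.stack([a, b])
print(type(stacked).__name__)        # tensorExtra
print(hasattr(stacked, "extra"))     # False
print(type(stacked[0]).__name__)     # tensorExtra
print(hasattr(stacked[0], "extra"))  # False
```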
The core problem is that a Tensor is in no sense a collection (e.g., a list)
of Tensors (nor a collection of other objects). You can’t store a Tensor (nor
another object) “in” a Tensor.
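A quick way to see that a Tensor holds numbers rather than objects is to index a plain Tensor twice. Each indexing operation builds a brand-new Tensor object (a view onto the same storage), so a Python attribute set on one "element" is not stored anywhere in the parent, while writes to the values are.

```python
import torch

t = torch.zeros(3, 2)

# Each indexing operation creates a new Tensor object viewing t's storage
print(t[0] is t[0])  # False

row = t[0]
row.extra = "hello"            # attaches to this one Python object only
print(hasattr(t[0], "extra"))  # False: t[0] builds yet another new object

# Writing values through the view does reach t, because only numbers are stored
row.fill_(7.0)
print(t[0])  # tensor([7., 7.])
```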
A regular Tensor object does not have an extra attribute. But furthermore,
your tensorExtra object does not have per-element extra attributes – it
only carries a single extra attribute for the entire tensorExtra object as a
whole.
And to reiterate what I said above, neither a Tensor nor a tensorExtra is
a collection, so they can’t contain multiple tensorExtras, each with its own extra attribute.
To belabor the point, you can’t, because a Tensor is not a collection in which
you can store multiple objects. (I doubt it fits the use case you have in mind,
but you certainly can store multiple tensorExtras in, say, a list, and if you
do, their extra attributes will be preserved and be accessible.)
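For completeness, here is a minimal sketch of the list approach, again using a hypothetical tensorExtra whose __new__ requires extra. The list keeps references to the original objects, so their attributes survive. If batched math is also needed, one option (an assumption about the use case, not the only design) is to stack just the numbers and keep the extras in a parallel list indexed the same way.

```python
import torch

class tensorExtra(torch.Tensor):
    # Hypothetical minimal subclass mirroring the one described in the question
    def __new__(cls, data, extra):
        obj = torch.Tensor._make_subclass(cls, torch.as_tensor(data))
        obj.extra = extra
        return obj

items = [tensorExtra([1.0, 2.0], extra="a"), tensorExtra([3.0, 4.0], extra="b")]
print(items[0].extra)  # a  (the list holds the original objects)

# Stack the numbers for batched math (the result has no usable extra itself)
# and keep the per-item extras in a parallel list with the same indexing.
data = torch.stack(items)
extras = [x.extra for x in items]
print(extras[1])  # b
```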
As an aside, here are some points I do not understand:
As far as I know, torch.stack() is implemented in C++ (with a Python wrapper).
Your tensorExtra is a Python class about which C++ knows nothing directly.
However, torch.stack() somehow knows – presumably through the Python
wrapper machinery – to return a tensorExtra when called with a tensorExtra
(albeit a “broken” tensorExtra that lacks an extra attribute).
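One plausible candidate for that "wrapper machinery" is PyTorch's __torch_function__ protocol: when any argument (including tensors inside a list) is a Tensor subclass, torch functions hand the call to that subclass's __torch_function__ classmethod. The sketch below assumes a reasonably recent PyTorch. It uses a hypothetical Logged subclass that only prints what it intercepts and then defers to the default implementation, which returns the subclass type.

```python
import torch

class Logged(torch.Tensor):
    # Hypothetical subclass used only to observe dispatch; it stores no extra data
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        print("intercepted:", getattr(func, "__name__", func))
        # Defer to the default behaviour, which runs func and converts
        # Tensor results to cls
        return super().__torch_function__(func, types, args, kwargs)

x = torch.ones(2).as_subclass(Logged)
y = torch.stack([x, x])  # prints: intercepted: stack
print(type(y).__name__)  # Logged
```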
So, roughly speaking, torch.stack() is trying to implement a covariant return
type when passed* the Tensor subclass tensorExtra. I’m not sure such a thing
is sound, because (without being overridden) torch.stack() can’t know how to
return a tensorExtra, as evidenced by the fact that the best it can do is to return
a broken tensorExtra that lacks the extra attribute.
I have no idea how or why torch.stack() returns a tensorExtra when a tensorExtra is passed in.
*) Of course, the argument to torch.stack() is not a tensorExtra (nor a plain Tensor). Rather, the argument is a (Python) list. Now a Python list is not a
parameterized type, but to further complicate the question of torch.stack()'s return
type, at a conceptual level, should a hypothetical list_of_tensorExtras be a
subtype of list_of_Tensors, be a supertype thereof, be both, or be neither?
That makes sense. I was getting thrown off by what you mentioned later: torch.stack was returning a tensorExtra object and, for each “element” in it, type(element) was returning tensorExtra.
I’m not sure either. I’ve tried testing this out and editing the class definition, but it still returns a “broken” tensorExtra that lacks the extra attribute, which is strange, since the __new__ function explicitly requires an extra value to be supplied for a tensorExtra to be instantiated.
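A likely explanation for the "broken" object, assuming a recent PyTorch: the default torch.Tensor.__torch_function__ runs the underlying function and re-wraps Tensor results with as_subclass(cls), and as_subclass creates the subclass instance without ever calling __new__ or __init__. So the required extra argument is simply never involved. The sketch below shows that, plus one hypothetical way to override __torch_function__ so that torch.stack collects the inputs' extras into a list on the result. The list-collecting policy is an assumption for illustration, and indexing still goes through the default path.

```python
import torch

class tensorExtra(torch.Tensor):
    # Hypothetical class whose __new__ insists on an extra argument
    def __new__(cls, data, extra):
        obj = torch.Tensor._make_subclass(cls, torch.as_tensor(data))
        obj.extra = extra
        return obj

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        # Default path: run func, then re-wrap Tensor results via as_subclass(cls),
        # which does not call __new__, so no extra attribute is set
        ret = super().__torch_function__(func, types, args, kwargs)
        if func is torch.stack and isinstance(ret, cls):
            tensors = args[0] if args else kwargs["tensors"]
            # Hypothetical policy: keep each input's extra, in order, as a list
            ret.extra = [getattr(t, "extra", None) for t in tensors]
        return ret

a = tensorExtra([1.0, 2.0], extra="a")
b = tensorExtra([3.0, 4.0], extra="b")

# as_subclass() produces a tensorExtra without running __new__
c = torch.zeros(2).as_subclass(tensorExtra)
print(hasattr(c, "extra"))  # False

stacked = torch.stack([a, b])
print(stacked.extra)                 # ['a', 'b']
print(hasattr(stacked[0], "extra"))  # False: indexing still uses the default path
```

Note that this only attaches a single list to the stacked object as a whole; stacked[0] is still a new object without its own extra, which is consistent with the point that a Tensor is not a collection of objects.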