Let’s say I have a tensor v of shape (dim_m, dim_n, 1, dim_q). The tensor is very large.
For some reason, I want to support accessing it with v[m,n,p,q] where p can actually go beyond 0.
In my application, the desired behaviour is to basically ignore p, so that v[m,n,p,q] returns v[m,n,0,q].
This means that out-of-bounds indices should clamp to the last available entry on that dimension.
An easy solution would of course be to duplicate v along the p dimension, but I don’t want to do that, both for speed and because v is a huge tensor.
I can’t modify the fact that I will ask for some p that go beyond 0. (I isolated the core question, the fact that I want to index there is a consequence of using some complicated torch.gather stuff)
In essence, I want some kind of “safe indexing” tensor. Does that exist? Can I subclass torch.Tensor and override its indexing to accomplish this without breaking everything?
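For reference, the duplication mentioned above does not necessarily have to copy memory. Here is a minimal sketch, assuming an upper bound on p is known in advance (max_p below is a hypothetical value, not something given in the question). torch.Tensor.expand returns a view whose expanded dimension has stride 0, so every p reads the same storage as p = 0.

```python
import torch

dim_m, dim_n, dim_q = 3, 4, 6   # small sizes, for illustration only
max_p = 5                       # hypothetical upper bound on p (assumption)

v = torch.randn(dim_m, dim_n, 1, dim_q)

# expand() broadcasts the size-1 dimension without allocating new memory;
# -1 means "keep this dimension's size unchanged".
v_view = v.expand(-1, -1, max_p, -1)

print(v_view.shape)                                     # torch.Size([3, 4, 5, 6])
print(v_view.stride(2))                                 # 0
print(v_view.data_ptr() == v.data_ptr())                # True
print(torch.equal(v_view[:, :, 3, :], v[:, :, 0, :]))   # True
```

Since the view is an ordinary tensor, it can also be passed to functions such as torch.gather. It should be treated as read-only, because all values of p share the same memory.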
Thanks for your answer, but this doesn’t answer my question, because I cannot control which indices the tensor v will be accessed with. This means that v will be accessed with your inp variable. I need to address my problem elsewhere, e.g. by deriving a SafeAccessTensor class.
I believe just subclassing torch.Tensor might do the trick. Something like this:
import torch

# suppose you want to "virtually expand" dim 2
class MyTensor(torch.Tensor):
    def __getitem__(self, index):
        # assumes index is a full tuple such as (m, n, p, q)
        index = list(index)
        index[2] = 0  # ignore p: always read entry 0 on dim 2
        return super().__getitem__(tuple(index))

tensor = torch.randn(3, 4, 1, 6)
mytensor = MyTensor(tensor)
print(mytensor[0, 0, 0, 0], mytensor[0, 0, 1, 0], mytensor[0, 0, 2, 0])  # works and is identical, as desired
print(mytensor[0, 0, 0, 0], mytensor[1, 0, 0, 0], mytensor[2, 0, 0, 0])  # is different
Output:
tensor(0.1178) tensor(0.1178) tensor(0.1178)
tensor(0.1178) tensor(1.3803) tensor(-1.2770)
This mock class does not handle access via a slice, such as mytensor[0, 0, :, 0] (the slice is silently replaced by 0), nor a single integer index such as mytensor[0]. You have to fiddle around and make that work by adding a bit of logic around the index handling:
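A minimal sketch of what that extra logic could look like. The class name SafeAccessTensor is borrowed from the question, VIRTUAL_DIM is a hypothetical constant, and mapping any slice on that dimension to the single real entry is my assumption about the desired behaviour. Ellipsis and None inside the index are not handled here.

```python
import torch

VIRTUAL_DIM = 2  # hypothetical: position of the size-1 dimension to ignore


class SafeAccessTensor(torch.Tensor):
    def __getitem__(self, index):
        # Normalise a single index (int, slice, ...) to a tuple.
        if not isinstance(index, tuple):
            index = (index,)
        index = list(index)
        if len(index) > VIRTUAL_DIM:
            p = index[VIRTUAL_DIM]
            if isinstance(p, int):
                # Any integer p reads the only real entry.
                index[VIRTUAL_DIM] = 0
            elif isinstance(p, slice):
                # Assumption: any slice returns the single real entry.
                index[VIRTUAL_DIM] = slice(0, 1)
        return super().__getitem__(tuple(index))


base = torch.arange(3 * 4 * 1 * 6, dtype=torch.float32).reshape(3, 4, 1, 6)
safe = base.as_subclass(SafeAccessTensor)  # shares memory with base

print(safe[1, 2, 5, 3].item() == base[1, 2, 0, 3].item())  # True
print(safe[1, 2, :, 3].shape)                               # torch.Size([1])
print(safe[0].shape)                                        # torch.Size([4, 1, 6])
```

Indexing with tensors on the virtual dimension (advanced indexing) would still need its own branch.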
This indeed answers the question, although it turns out that the function I am using is apparently implemented in a way that __getitem__ is not called. Still, that requires its own question and is out of the scope of this one, which you answered nicely. Thanks for this!
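For anyone landing here with the same torch.gather issue: functions like torch.gather do not go through __getitem__, but they can be intercepted on a subclass via __torch_function__. The following is only a sketch under assumptions: GatherSafeTensor is a hypothetical name, it only handles the case where input, dim and index are passed positionally, and it clamps the index along the gather dimension only.

```python
import torch


class GatherSafeTensor(torch.Tensor):
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = dict(kwargs or {})
        # Only touch gather calls with (input, dim, index) given positionally.
        if func in (torch.gather, torch.Tensor.gather) and len(args) >= 3:
            source, dim, index = args[0], args[1], args[2]
            last = source.size(dim) - 1
            # Clamp out-of-bounds indices to the last valid entry on `dim`.
            args = (source, dim, index.clamp(max=last)) + tuple(args[3:])
        return super().__torch_function__(func, types, args, kwargs)


v = torch.arange(2 * 3 * 1 * 4, dtype=torch.float32).reshape(2, 3, 1, 4)
v = v.as_subclass(GatherSafeTensor)

# Indices 1 and 5 are out of bounds on dim 2 (size 1) and get clamped to 0.
idx = torch.tensor([0, 1, 5]).view(1, 1, 3, 1).expand(2, 3, 3, 4)
out = torch.gather(v, 2, idx)

print(out.shape)                                    # torch.Size([2, 3, 3, 4])
print(torch.equal(out[:, :, 2, :], v[:, :, 0, :]))  # True
```

Whether this fits depends on which function actually receives the tensor; other operations would each need their own handling.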