Safe (clamped) indexing for a tensor?

Hi!

Let's say I have a tensor v of shape (dim_m, dim_n, 1, dim_q). The tensor is super big.

For some reason, I want to support accessing it as v[m, n, p, q], where p can actually be greater than 0.

In my application, the desired behaviour is to simply ignore p, i.e. return v[m, n, 0, q].
In other words, an out-of-bounds index should be clamped to the last valid entry along that dimension.
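To pin down the requested semantics, here is a minimal sketch of clamped lookup for the vectorised case. It assumes the indices arrive as four broadcastable integer tensors (as they would in gather-style code); clamped_lookup is a hypothetical helper name and the small shapes are only for illustration.

```python
import torch

def clamped_lookup(v, m, n, p, q):
    # Hypothetical helper: clamp every index tensor into [0, size - 1]
    # of its dimension, so any p maps to 0 when v.size(2) == 1
    idx = tuple(i.clamp(0, s - 1) for i, s in zip((m, n, p, q), v.shape))
    # Advanced indexing with four index tensors picks one element per position
    return v[idx]

v = torch.randn(3, 4, 1, 5)
m = torch.tensor([0, 2, 1])
n = torch.tensor([3, 0, 1])
p = torch.tensor([0, 7, 2])  # p "goes beyond 0"
q = torch.tensor([4, 1, 0])

out = clamped_lookup(v, m, n, p, q)
print(torch.equal(out, v[m, n, 0, q]))  # True
```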

  • An easy solution would of course be to duplicate v along the p dimension, but I don't want to do that, both for speed and because v is a huge tensor.
  • I can't avoid asking for values of p beyond 0. (I isolated the core question here; the out-of-range indexing is a consequence of some complicated torch.gather code.)

In essence, I want some kind of "safe indexing" tensor. Does that exist? Can I subclass torch.Tensor to accomplish this without breaking everything?

As you said, you could ignore p.

But if you want to clamp your inputs, you could use torch.clamp (see its documentation).

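# Example
import torch

dim_m = 10
dim_n = 20
dim_p = 1  # v has a single entry along dim 2
dim_q = 30

# Valid indices along each dimension run from 0 to size - 1
min_ = torch.zeros(4, dtype=torch.long)
max_ = torch.tensor([dim_m, dim_n, dim_p, dim_q]) - 1

# Random, possibly out-of-range index vector [m, n, p, q]
inp = torch.randint(-1, 50, (4,))

print(inp)

clamped_inp = torch.clamp(inp, min_, max_)
print(clamped_inp)

# Example output (inp is random):
# inp:     tensor([ 9, 48, 15, 33])
# clamped: tensor([ 9, 19,  0, 29])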

Thanks for your answer, but this doesn't solve my problem, because I basically cannot control which indices v will be accessed with. That is, v will be accessed with something like your unclamped inp, not clamped_inp. I need to address the problem elsewhere, e.g. by deriving a SafeAccessTensor class.

Then I would restrict access to v through a function.
Inside it, check whether the input is valid and clamp it if not.

Something like:

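def get_v(v, idx_m, idx_n, idx_p, idx_q):
    idx = (idx_m, idx_n, idx_p, idx_q)
    # clamp each index into [0, size - 1] of its dimension
    idx = tuple(min(max(i, 0), s - 1) for i, s in zip(idx, v.shape))
    # select v with the valid indices
    return v[idx]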

But maybe someone else has a better solution that actually solves your problem.

Thanks again a million for taking the time =)

Unfortunately, I can't do that, because the indexing happens inside a function internal to torch (torch.gather in this case).
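Since the out-of-range p comes from torch.gather, one option not raised in this thread may be worth trying: Tensor.expand returns a view that repeats a size-1 dimension with stride 0, so nothing is copied. This sketch assumes the gather runs along dim 3 with an index whose dim 2 is larger than 1 (torch.gather requires index.size(d) <= input.size(d) for every d other than the gather dim); the sizes and P are made up for illustration.

```python
import torch

# Toy sizes (assumed for illustration); v has a size-1 dim 2
dim_m, dim_n, dim_q = 3, 4, 5
P = 6  # hypothetical number of "virtual" entries along dim 2
v = torch.randn(dim_m, dim_n, 1, dim_q)

# Index for gathering along dim 3; its dim 2 has size P > 1, so
# torch.gather(v, 3, index) would be rejected (index larger than v on dim 2)
index = torch.randint(0, dim_q, (dim_m, dim_n, P, dim_q))

# expand() gives a view with stride 0 along dim 2; no data is copied
v_virtual = v.expand(dim_m, dim_n, P, dim_q)
out = torch.gather(v_virtual, 3, index)

# Same result as gathering from a physically duplicated copy
reference = torch.gather(v.repeat(1, 1, P, 1), 3, index)
print(torch.equal(out, reference))  # True
```

Since all P entries of the expanded view share the same memory, it should be treated as read-only.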

1 Like

I believe just subclassing torch.Tensor might do the trick. Something like this:

import torch

# Suppose you want to "virtually expand" dim 2
class MyTensor(torch.Tensor):
    def __getitem__(self, index):
        index = list(index)
        index[2] = 0  # any p reads the single stored entry along dim 2
        return super().__getitem__(tuple(index))

tensor = torch.randn(3, 4, 1, 6)
mytensor = tensor.as_subclass(MyTensor)  # view the same data as a MyTensor, no copy

print(mytensor[0, 0, 0, 0], mytensor[0, 0, 1, 0], mytensor[0, 0, 2, 0])  # works and is identical, as desired
print(mytensor[0, 0, 0, 0], mytensor[1, 0, 0, 0], mytensor[2, 0, 0, 0])  # is different

Output:
tensor(0.1178) tensor(0.1178) tensor(0.1178)
tensor(0.1178) tensor(1.3803) tensor(-1.2770)

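This mock class does not handle slices correctly: for mytensor[0, 0, :, 0] the slice is overwritten with 0, so the result loses that dimension. It also fails on a single, non-tuple index such as mytensor[0], since list(0) raises a TypeError. You have to fiddle around a bit and add some logic, e.g. around a check like:

if isinstance(index[2], slice):

But it shouldn't be too bad!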
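A minimal sketch of that extra logic, assuming integer indices on dim 2 should be redirected to 0 while slices pass through unchanged (a slice over a size-1 dim already stays in bounds). It also wraps a non-tuple index; Ellipsis and tensor indices are not handled. as_subclass is used to view the data as MyTensor without copying.

```python
import torch

class MyTensor(torch.Tensor):
    # Refinement of the mock class: only integer indices on dim 2 are
    # redirected to 0; slices and all other dims are left untouched.
    def __getitem__(self, index):
        if not isinstance(index, tuple):
            index = (index,)  # e.g. mytensor[0]
        index = list(index)
        if len(index) > 2 and isinstance(index[2], int):
            index[2] = 0  # any integer p reads the single stored entry
        return super().__getitem__(tuple(index))

tensor = torch.randn(3, 4, 1, 6)
mytensor = tensor.as_subclass(MyTensor)

print(mytensor[0, 0, 5, 0])         # same value as tensor[0, 0, 0, 0]
print(mytensor[0, 0, :, 0].shape)   # torch.Size([1]), dim 2 is kept
print(mytensor[0].shape)            # torch.Size([4, 1, 6])
```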

1 Like

This does answer the question, although it turns out that the function I am using is implemented in a way that never calls __getitem__. Still, that needs its own question and is out of the scope of this one, which you answered nicely. Thanks for this!
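Functions like torch.gather run in C++ and never go through the Python __getitem__, but a subclass can still intercept them through the __torch_function__ protocol. The sketch below is hypothetical and only handles positional gather(input, dim, index) calls: it expands size-1 dims of the input to the index's size (a zero-copy view) before delegating to the default implementation. Whether this fits depends on how gather is actually called in the real code.

```python
import torch

class ExpandForGather(torch.Tensor):
    # Hypothetical subclass: before torch.gather runs, size-1 dims of the
    # input (other than the gather dim) are expanded to the index's size.
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        # Only positional calls gather(input, dim, index) are handled
        if func in (torch.gather, torch.Tensor.gather) and len(args) >= 3:
            inp, dim, index = args[0], args[1], args[2]
            dim = dim % inp.dim()
            shape = [
                index.size(d) if d != dim and inp.size(d) == 1 else inp.size(d)
                for d in range(inp.dim())
            ]
            # expand() is a zero-copy view (stride 0 on expanded dims)
            args = (inp.expand(*shape), dim, index) + tuple(args[3:])
        return super().__torch_function__(func, types, args, kwargs)

v = torch.randn(3, 4, 1, 5)
v_safe = v.as_subclass(ExpandForGather)
index = torch.randint(0, 5, (3, 4, 6, 5))  # dim 2 of index "goes beyond" v
out = torch.gather(v_safe, 3, index)       # also works as v_safe.gather(3, index)
```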