How to index 4D tensors with 2D tensors?

Hi!
I have a 4D tensor A with shape (B, X, Y, Z) and a tensor B with shape (N, 4), where each row stores the index of one element of tensor A.
For example:

B = torch.Tensor([[0, 0, 0, 0], [0, 0, 0, 1]])
# I want to set these elements of tensor A:
A[0, 0, 0, 0] = 1
A[0, 0, 0, 1] = 1

How can I access tensor A using tensor B without a for loop? I found it difficult to do this with scatter, index_select, or index_put.

Thank you!

If I'm understanding you correctly, you should be able to use a long tensor.

b = torch.Tensor([[0, 0, 0, 0], [0, 0, 0, 1]]).type(torch.long)
a[b] = 1

Thank you for your reply.
I tried your method but got:

RuntimeError: CUDA out of memory

I think your solution selects entire rows rather than single elements:

>>> a = torch.arange(8).view(2, 4)
>>> a
tensor([[0, 1, 2, 3],
        [4, 5, 6, 7]])
>>> b = torch.tensor([[0, 0, 0, 0], [0, 0, 0, 1]]).type(torch.long)
>>> a[b]
tensor([[[0, 1, 2, 3],
         [0, 1, 2, 3],
         [0, 1, 2, 3],
         [0, 1, 2, 3]],

        [[0, 1, 2, 3],
         [0, 1, 2, 3],
         [0, 1, 2, 3],
         [4, 5, 6, 7]]])
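
The shape confirms it: indexing with a single integer tensor gathers whole slices along dim 0, so the result has shape b.shape + a.shape[1:]:

>>> a[b].shape
torch.Size([2, 4, 4])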

Ah yeah, whoops. I'm not sure if there is any way of doing that via indexing, unfortunately.

You could do something like this.

a = torch.arange(8).view(1, 1, 2, 4)
b = torch.tensor([[0, 0, 0, 0], [0, 0, 0, 1], [0, 0, 1, 3]])

print(a[b[:, 0], b[:, 1], b[:, 2], b[:, 3]])
# Output:
tensor([0, 1, 7])
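
The same lookup can also be written without spelling out the four columns: indexing with the tuple of 1-D index tensors that unbind produces gives the same result and generalizes to any number of dimensions:

print(a[b.unbind(dim=1)])  # one 1-D index tensor per column of b
# Output:
tensor([0, 1, 7])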

Thank you. Your method is simple and cool. :smiley: I figured out a method using index_put_:

a = torch.arange(8).view(1, 1, 2, 4)
b = torch.tensor([[0, 0, 0, 0], [0, 0, 0, 1], [0, 0, 1, 3]])
inds = torch.chunk(b, 4, dim=1)      # tuple of four (3, 1) index tensors, one per dimension
a.index_put_(inds, torch.tensor(0))  # writes 0 in-place at each indexed position
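
A side note: index_put_ also accepts an accumulate flag; with accumulate=True the values are summed into positions that appear more than once in b instead of overwriting each other. A small sketch:

a = torch.zeros(1, 1, 2, 4, dtype=torch.long)
b = torch.tensor([[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 1, 3]])  # first row duplicated
a.index_put_(torch.chunk(b, 4, dim=1), torch.tensor(1), accumulate=True)
print(a[0, 0])
# Output:
tensor([[2, 0, 0, 0],
        [0, 0, 0, 1]])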

I'm not sure which one runs faster, so I tested them with PyTorch 1.9:

import torch
import time

a = torch.arange(468 * 468, device='cuda:0').view(1, 1, 468, 468)
b = torch.randint(0, 468, (20000, 4), device='cuda:0')
b[:, 0] = b[:, 1] = 0  # dims 0 and 1 of a have size 1, so those index columns must be 0

t1 = time.time()
for i in range(10000):
    a[b[:, 0], b[:, 1], b[:, 2], b[:, 3]] = 1
print(f"Time: {time.time() - t1}") 

t2 = time.time()
for i in range(10000):
    inds = torch.chunk(b, 4, dim=1)
    a.index_put_(inds, torch.tensor(1).to(b.device))
print(f"Time: {time.time() - t2}")
# Output:
Time: 0.5770895481109619
Time: 0.7494716644287109

Using index_put_ is slower here, and it remains slower even when the overhead of torch.chunk is excluded (0.41 s vs 0.44 s).


I am sorry; index_put_ is actually much faster. The calls to torch.chunk and torch.tensor(1).to(b.device) inside the loop caused the slowness.

import torch
import time

a = torch.arange(468 * 468, device='cuda:0').view(1, 1, 468, 468)
b = torch.randint(0, 468, (20000, 4), device='cuda:0')
b[:, 0] = b[:, 1] = 0

t1 = time.time()
for i in range(10000):
    a[b[:, 0], b[:, 1], b[:, 2], b[:, 3]] = 1
print(f"Time: {time.time() - t1}") 

inds = torch.chunk(b, 4, dim=1)       # hoisted out of the timing loop
value = torch.tensor(1).to(b.device)  # created once instead of on every iteration
t2 = time.time()
for i in range(10000):
    a.index_put_(inds, value)
print(f"Time: {time.time() - t2}")
# Output:
Time: 0.38867831230163574
Time: 0.13382434844970703
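
A caveat on both sets of numbers: CUDA kernels launch asynchronously, so time.time() can return before the queued GPU work has actually finished. Calling torch.cuda.synchronize() around the loop gives a more faithful measurement (the absolute numbers may change). For example, for the index_put_ loop above:

torch.cuda.synchronize()  # finish pending GPU work before starting the clock
t2 = time.time()
for i in range(10000):
    a.index_put_(inds, value)
torch.cuda.synchronize()  # wait for the queued kernels before stopping the clock
print(f"Time: {time.time() - t2}")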