I am sorry. `index_put_` seems to be much faster.
`torch.chunk` and `torch.tensor(1).to(b.device)` cause the slowness.
import torch
import time
# Micro-benchmark: advanced-index assignment (``a[...] = 1``) versus an
# ``index_put_`` call whose index tuple and value tensor are built once,
# outside the timed loop.

# Use the first GPU when present; fall back to CPU so the script still runs.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

a = torch.arange(468 * 468, device=device).view(1, 1, 468, 468)
b = torch.randint(0, 468, (20000, 4), device=device)
# The two leading dimensions of ``a`` have size 1, so 0 is the only valid
# index there.
b[:, 0] = b[:, 1] = 0

# Built once up front so the index_put_ loop times nothing but the write.
inds = torch.chunk(b, 4, dim=1)
value = torch.tensor(1).to(b.device)


def sync_device():
    """Block until all work queued on ``a``'s CUDA device is done (no-op on CPU)."""
    if a.is_cuda:
        torch.cuda.synchronize(a.device)


def timed(step, iters=10000):
    """Return the wall-clock seconds taken by ``iters`` calls of ``step``.

    CUDA kernels are launched asynchronously, so the device is synchronized
    before the clock starts (earlier work must not leak into this
    measurement) and again before it stops (the queued kernels must actually
    have finished).
    """
    sync_device()
    start = time.perf_counter()
    for _ in range(iters):
        step()
    sync_device()
    return time.perf_counter() - start


def setitem_step():
    """Write 1 via indexing syntax; the column slices are rebuilt on every call."""
    a[b[:, 0], b[:, 1], b[:, 2], b[:, 3]] = 1


def index_put_step():
    """Write 1 via ``index_put_`` using the pre-built ``inds`` and ``value``."""
    a.index_put_(inds, value)


if __name__ == "__main__":
    print(f"Time: {timed(setitem_step)}")
    print(f"Time: {timed(index_put_step)}")
# Output:
Time: 0.38867831230163574
Time: 0.13382434844970703