Hmm, I'm not sure it's possible to do this in a single line.
Typically (here and in numpy), passing a list as an index means “gather the elements at those indices”.
For example:
a[[0,2,3],1:5,3:8]
would take the slice [1:5,3:8] from each of the elements 0, 2 and 3 along the first dimension.
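A quick way to check this gather behaviour is a minimal numpy sketch. The array `a` and its shape here are made up for illustration and are not from your code:

```python
import numpy as np

# Hypothetical array, chosen only to make the shapes easy to read.
a = np.arange(5 * 6 * 8).reshape(5, 6, 8)

# The list gathers entries 0, 2 and 3 along the first axis;
# the two slices are then applied to each gathered entry.
out = a[[0, 2, 3], 1:5, 3:8]
print(out.shape)

# Same result as stacking the per-entry slices by hand.
stacked = np.stack([a[i, 1:5, 3:8] for i in (0, 2, 3)])
print(np.array_equal(out, stacked))
```

Note that the same slice is applied to every gathered entry, which is why this form can't express a different window per sample.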
In your case, however, it's more like combining a for loop with indexing.
So you could do something like
import numpy as np
x = np.array([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]])
print('Our array is:')
print(x)
print('\n')
# rows and cols are paired element-wise: (0,0), (0,2), (3,0), (3,2)
rows = np.array([[0, 0], [3, 3]])
cols = np.array([[0, 2], [0, 2]])
y = x[rows, cols]
print('The corner elements of this array are:')
print(y)
Our array is:
[[ 0  1  2]
 [ 3  4  5]
 [ 6  7  8]
 [ 9 10 11]]
The corner elements of this array are:
[[ 0  2]
 [ 9 11]]
(From numpy’s examples)
That is, you pass index matrices that select elements one by one, but that would be a bit messy compared to a plain for loop:
import torch
N = 4
a = torch.randn(N, 3, 10, 10)
b = torch.randn(N, 3, 10, 10)
X1 = torch.randint(0, 5, (N,))
Y1 = torch.randint(0, 5, (N,))
X2 = torch.randint(6, 9, (N,))
Y2 = torch.randint(6, 9, (N,))
for i, (x1, y1, x2, y2) in enumerate(zip(X1, Y1, X2, Y2)):
    # copy a different window for each sample in the batch
    b[i, :, x1:x2, y1:y2] = a[i, :, x1:x2, y1:y2]
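If you really want to avoid the loop, one possible alternative (a different technique from the one above, so treat it as a sketch) is to build a boolean mask per sample with broadcasting and pick between `a` and `b` with `torch.where`. The names `rows`, `cols`, `row_mask`, `col_mask`, `mask` and `expected` are mine, and the seed is only there to make the run repeatable:

```python
import torch

torch.manual_seed(0)
N, C, H, W = 4, 3, 10, 10
a = torch.randn(N, C, H, W)
b = torch.randn(N, C, H, W)
X1 = torch.randint(0, 5, (N,))
Y1 = torch.randint(0, 5, (N,))
X2 = torch.randint(6, 9, (N,))
Y2 = torch.randint(6, 9, (N,))

# Reference result from the explicit loop, for comparison.
expected = b.clone()
for i in range(N):
    x1, y1, x2, y2 = int(X1[i]), int(Y1[i]), int(X2[i]), int(Y2[i])
    expected[i, :, x1:x2, y1:y2] = a[i, :, x1:x2, y1:y2]

# Per-sample boolean mask: True inside [x1, x2) x [y1, y2).
rows = torch.arange(H).view(1, H, 1)
cols = torch.arange(W).view(1, 1, W)
row_mask = (rows >= X1.view(N, 1, 1)) & (rows < X2.view(N, 1, 1))  # (N, H, 1)
col_mask = (cols >= Y1.view(N, 1, 1)) & (cols < Y2.view(N, 1, 1))  # (N, 1, W)
mask = (row_mask & col_mask).unsqueeze(1)  # (N, 1, H, W), broadcasts over channels

# Take a inside the window and b everywhere else.
result = torch.where(mask, a, b)
print(torch.equal(result, expected))
```

This touches every element of the tensors instead of just the windows, so whether it beats the loop depends on N and the window sizes. For small batches the for loop is probably easier to read.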