I have a tensor t of shape (38, 38, 7, 7, 21), whose dimensions correspond to (x, y, i, j, c).

Suppose I have the bounds x_min, x_max, y_min, y_max, two indices i and j, and a class index c.

I want to sum all elements in the range

(x_min:x_max, y_min:y_max, i, j, c)

I am thinking of something like

torch.sum(t[x_min:x_max, y_min:y_max, i:i+1, j:j+1, c:c+1])

but I suspect this would not work if, for instance, i = 6, so I would need if statements. Is there a cleaner way to achieve this?
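For reference, here is a minimal, self-contained sketch of the setup. The bounds and indices are hypothetical values chosen for illustration, with i = 6 as the last valid index along the third dimension. It uses an integer-valued tensor so the sums are exact, evaluates the slice expression above, and compares it with plain integer indexing at i, j, c, which selects the same elements but drops those dimensions instead of keeping them with size 1.

```python
import torch

# Hypothetical tensor with the shape from the question: (x, y, i, j, c).
# Integer values (int64) keep the sums exact, so results compare with ==.
t = torch.arange(38 * 38 * 7 * 7 * 21).reshape(38, 38, 7, 7, 21)

# Hypothetical bounds and indices; i = 6 is the last valid index along dim 2.
x_min, x_max = 3, 10
y_min, y_max = 5, 20
i, j, c = 6, 2, 4

# Slice-based version from the question: i:i+1 etc. keep size-1 dimensions.
sliced = t[x_min:x_max, y_min:y_max, i:i + 1, j:j + 1, c:c + 1]
slice_sum = torch.sum(sliced)

# Integer indexing picks one position along i, j, c and removes those dims.
indexed = t[x_min:x_max, y_min:y_max, i, j, c]
int_sum = indexed.sum()

print(sliced.shape)                  # torch.Size([7, 15, 1, 1, 1])
print(indexed.shape)                 # torch.Size([7, 15])
print(torch.equal(slice_sum, int_sum))  # True
```

Both expressions reduce over the same 7 × 15 block of elements; the only difference is whether the singleton dimensions are kept in the intermediate view.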