Hi, I actually found a way to select sub-matrices in a vectorized way.
Assume you have a start_x vector of per-sample start indices with shape (batch,) and a start_y vector of per-sample start indices with shape (batch,).
I use the idea of NumPy advanced indexing from here: numpy doc
First, create a map of all the indices I am interested in, with shape (batch_size, 2 * step_size, 2 * step_size), where step_size is nn_cfg.window_size_around_center
(in my example I use a window around the center, which I detect with argmax):
start_x = target_mat_inx[:, 0] - nn_cfg.window_size_around_center # shape(batch_size,)
start_y = target_mat_inx[:, 1] - nn_cfg.window_size_around_center # shape(batch_size,)
start_x_mat = (start_x.reshape(-1, 1) + torch.arange(2*nn_cfg.window_size_around_center, device=device)).unsqueeze(-1).repeat(1, 1, 2*nn_cfg.window_size_around_center) # shape(batch_size, 2* step_size, 2* step_size)
start_y_mat = (start_y.reshape(-1, 1) + torch.arange(2*nn_cfg.window_size_around_center, device=device)).unsqueeze(1).repeat(1, 2*nn_cfg.window_size_around_center, 1) # shape(batch_size, 2* step_size, 2* step_size)
Now create the batch selection vector:
batch_index = torch.arange(nn_cfg.batch_size, device=device).reshape(-1,1,1)
Then select the sub-matrices for the whole batch:
matrix_mask_target = matrix[batch_index, start_x_mat, start_y_mat] # shape(batch_size, 2* step_size, 2* step_size)
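For anyone who wants to try this outside my project, here is a minimal self-contained sketch of the same idea. The names extract_windows, centers and half are made up for this example and are not part of my code. It also leans on broadcasting instead of .repeat, which I believe gives the same result with less memory for the index tensors. It assumes every window lies fully inside the matrix; starts that fall outside are not handled here.

```python
import torch

def extract_windows(matrix, centers, half):
    """Gather a (2*half, 2*half) window per sample from a (B, H, W) tensor.

    `centers` is a (B, 2) integer tensor of (row, col) positions.
    Hypothetical helper: assumes all windows lie inside the matrix.
    """
    batch = matrix.shape[0]
    offsets = torch.arange(2 * half, device=matrix.device)
    start_x = centers[:, 0] - half  # (B,)
    start_y = centers[:, 1] - half  # (B,)
    rows = (start_x.reshape(-1, 1) + offsets).unsqueeze(-1)  # (B, 2*half, 1)
    cols = (start_y.reshape(-1, 1) + offsets).unsqueeze(1)   # (B, 1, 2*half)
    batch_index = torch.arange(batch, device=matrix.device).reshape(-1, 1, 1)
    # The three index tensors broadcast to (B, 2*half, 2*half).
    return matrix[batch_index, rows, cols]

matrix = torch.arange(2 * 6 * 6).reshape(2, 6, 6)
centers = torch.tensor([[2, 3], [3, 2]])
windows = extract_windows(matrix, centers, half=1)
```

Each windows[i] should match the plain slice of matrix[i] that starts at (start_x[i], start_y[i]), which is an easy way to check the indexing on small inputs.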
Note: it is also possible to select sub-batches; just change batch_index to the specific selection and make sure that start_x and start_y align with your selection.
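A small sketch of the sub-batch case, with made-up values: selected, size and the start vectors below are only for illustration. The point is that start_x and start_y hold one entry per selected sample, in the same order as selected.

```python
import torch

matrix = torch.arange(4 * 5 * 5).reshape(4, 5, 5)

selected = torch.tensor([0, 3])  # hypothetical sub-batch: samples 0 and 3
start_x = torch.tensor([1, 2])   # row start for each selected sample
start_y = torch.tensor([0, 3])   # column start for each selected sample
size = 2                         # window side length

offsets = torch.arange(size)
rows = (start_x.reshape(-1, 1) + offsets).unsqueeze(-1)  # (2, size, 1)
cols = (start_y.reshape(-1, 1) + offsets).unsqueeze(1)   # (2, 1, size)

# Use the selection itself as the batch index instead of arange(batch_size).
sub = matrix[selected.reshape(-1, 1, 1), rows, cols]     # (2, size, size)
```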