Index lookup without for loop

Hi,

I have been struggling with the following problem for a while now and would appreciate your help.

I have a for loop in my forward pass that slows down the training immensely, but I do not know how to get rid of it.
I want to resample the features of an image based on optical flow: every position (u, v) with non-zero flow should take the feature value found at (u + delta_u, v + delta_v). To do that, I have the following code:

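# Get the (n, b, u, v) positions where at least one flow component is not zero.
# Reducing over the last dimension avoids visiting a position twice when both
# delta_u and delta_v are non-zero.
flow_files_non_zero_indices = torch.nonzero(
    (flow_files_shifted != 0).any(dim=-1), as_tuple=True
)
# Clone so that writes to the result do not modify the source features
face_images_features_resampled = face_images_features.clone()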

# Iterate over these indices
for index in zip(
    flow_files_non_zero_indices[0],
    flow_files_non_zero_indices[1],
    flow_files_non_zero_indices[2],
    flow_files_non_zero_indices[3],
):
    # Get index
    n, b, u, v = index

    # Get delta_u and delta_v from flow
    delta_u = flow_files_shifted[n, b, u, v, 0].type(torch.int64)
    delta_v = flow_files_shifted[n, b, u, v, 1].type(torch.int64)

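    # Index lookup. A negative index does not raise IndexError in PyTorch
    # (it wraps around), so the bounds are checked explicitly. If the target
    # is out of range, the original value at [n, b, u, v] is kept.
    src_u, src_v = u + delta_u, v + delta_v
    in_range_u = 0 <= src_u < face_images_features.shape[2]
    in_range_v = 0 <= src_v < face_images_features.shape[3]
    if in_range_u and in_range_v:
        face_images_features_resampled[n, b, u, v] = face_images_features[
            n, b, src_u, src_v
        ]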

I know of index_put_, but I do not see how I can use it here as I need to look up delta_u and delta_v based on the current index.

Is there a way to get rid of the for loop?
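
To make the goal concrete, here is a rough sketch of the kind of loop-free lookup I am after, using broadcast index grids and advanced indexing. It rests on assumptions of mine: face_images_features has shape (N, B, U, V) (optionally with trailing feature dimensions), flow_files_shifted has shape (N, B, U, V, 2), and any target outside the image, including a negative one, falls back to the original value. The function name resample_by_flow is hypothetical.

```python
import torch


def resample_by_flow(features: torch.Tensor, flow: torch.Tensor) -> torch.Tensor:
    """Sketch: out[n, b, u, v] = features[n, b, u + delta_u, v + delta_v].

    Assumed shapes: features (N, B, U, V, ...), flow (N, B, U, V, 2),
    with flow[..., 0] = delta_u and flow[..., 1] = delta_v.
    """
    N, B, U, V = features.shape[:4]
    dev = features.device

    # Index grids that broadcast against each other to (N, B, U, V)
    n = torch.arange(N, device=dev).view(N, 1, 1, 1)
    b = torch.arange(B, device=dev).view(1, B, 1, 1)
    u = torch.arange(U, device=dev).view(1, 1, U, 1)
    v = torch.arange(V, device=dev).view(1, 1, 1, V)

    # Truncate the flow to integers, as .type(torch.int64) does in the loop
    src_u = u + flow[..., 0].to(torch.int64)
    src_v = v + flow[..., 1].to(torch.int64)

    # Assumption: targets outside the image keep the original value.
    # Positions with zero flow map to themselves, so no mask is needed for them.
    valid = (src_u >= 0) & (src_u < U) & (src_v >= 0) & (src_v < V)
    src_u = torch.where(valid, src_u, u)
    src_v = torch.where(valid, src_v, v)

    # Advanced indexing gathers all positions in one call and returns a new tensor
    return features[n, b, src_u, src_v]
```

I would call it as face_images_features_resampled = resample_by_flow(face_images_features, flow_files_shifted). I have not verified that this matches the loop in every edge case, so corrections are welcome.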

Your help is very much appreciated!