Hi,
I have been struggling with the following problem for a while now and would appreciate your help.
I have a for loop in my forward pass that slows down the training immensely, but I do not know how to get rid of it.
What I want to do is to resample features of an image based on optical flow. To do that, I have the following code:
```python
import torch

# Get the (n, b, u, v) positions where either flow component is non-zero.
# (Calling nonzero on the 5-D flow tensor directly would list a position
# once per non-zero channel.)
flow_files_non_zero_indices = torch.nonzero(
    (flow_files_shifted != 0).any(dim=-1), as_tuple=True
)
# Clone, so the writes below do not overwrite values that are still to be read
face_images_features_resampled = face_images_features.clone()
H, W = face_images_features.shape[2], face_images_features.shape[3]
# Iterate over these indices
for n, b, u, v in zip(*flow_files_non_zero_indices):
    # Get delta_u and delta_v from flow (truncated toward zero)
    delta_u = flow_files_shifted[n, b, u, v, 0].long()
    delta_v = flow_files_shifted[n, b, u, v, 1].long()
    src_u, src_v = u + delta_u, v + delta_v
    # Index lookup. Negative indices wrap around instead of raising IndexError,
    # so check the bounds explicitly; out-of-range targets keep the original value.
    if 0 <= src_u < H and 0 <= src_v < W:
        face_images_features_resampled[n, b, u, v] = face_images_features[
            n, b, src_u, src_v
        ]
```
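One possible loop-free version, as a minimal sketch. It assumes the features have shape (N, B, H, W, ...) and the flow has shape (N, B, H, W, 2), with channel 0 as the offset along H (u) and channel 1 along W (v), and that the intended semantics are: each pixel takes the feature at (u + delta_u, v + delta_v) from the unmodified input, offsets are truncated toward zero, and pixels whose target lies outside the image keep their own feature. `resample_by_flow` is a hypothetical helper name. It builds the source coordinates for every pixel with broadcasting and gathers them all with a single advanced-indexing operation; pixels with zero flow simply map onto themselves, so no `nonzero` step is needed.

```python
import torch


def resample_by_flow(features: torch.Tensor, flow: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: nearest-pixel resampling of features by an integer-truncated flow.

    Assumes features: (N, B, H, W, *rest) and flow: (N, B, H, W, 2).
    """
    N, B, H, W = features.shape[:4]
    device = features.device
    # Integer offsets; .long() truncates toward zero like .type(torch.int64)
    delta = flow.long()
    # Coordinate grids shaped so they broadcast to (N, B, H, W)
    n = torch.arange(N, device=device).view(N, 1, 1, 1)
    b = torch.arange(B, device=device).view(1, B, 1, 1)
    u = torch.arange(H, device=device).view(1, 1, H, 1)
    v = torch.arange(W, device=device).view(1, 1, 1, W)
    # Source coordinates for every pixel: (N, B, H, W)
    src_u = u + delta[..., 0]
    src_v = v + delta[..., 1]
    # Out-of-range sources fall back to the pixel itself
    valid = (src_u >= 0) & (src_u < H) & (src_v >= 0) & (src_v < W)
    src_u = torch.where(valid, src_u, u)
    src_v = torch.where(valid, src_v, v)
    # One gather for all pixels; result has shape (N, B, H, W, *rest)
    return features[n, b, src_u, src_v]
```

Usage would be `face_images_features_resampled = resample_by_flow(face_images_features, flow_files_shifted)`. Because the result comes from plain indexing, gradients flow back into the features (though not into the flow, because of the integer cast). Working on the full grid also avoids `torch.nonzero`, whose data-dependent output size forces a host-device synchronization on CUDA.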
I know of index_put_, but I do not see how I can use it here as I need to look up delta_u and delta_v based on the current index.
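On the `index_put_` point: the per-index lookup of delta_u and delta_v can itself be vectorized, because indexing the flow tensor with the four index tensors returned by `torch.nonzero(..., as_tuple=True)` yields all K offset pairs at once. A minimal sketch under the same assumed shapes and semantics (features (N, B, H, W, ...), flow (N, B, H, W, 2), truncated offsets, out-of-range targets keep their value); `resample_sparse` is a hypothetical name.

```python
import torch


def resample_sparse(features: torch.Tensor, flow: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: update only positions with non-zero flow, using index_put_."""
    H, W = features.shape[2], features.shape[3]
    # (n, b, u, v) positions where either flow component is non-zero, each of shape (K,)
    n, b, u, v = torch.nonzero((flow != 0).any(dim=-1), as_tuple=True)
    # Look up all deltas at once: shape (K, 2), truncated toward zero
    delta = flow[n, b, u, v].long()
    src_u = u + delta[:, 0]
    src_v = v + delta[:, 1]
    # Drop positions whose source lies outside the image (they keep their value)
    valid = (src_u >= 0) & (src_u < H) & (src_v >= 0) & (src_v < W)
    n, b, u, v = n[valid], b[valid], u[valid], v[valid]
    src_u, src_v = src_u[valid], src_v[valid]
    # Read from the original tensor, write into a copy in a single call
    out = features.clone()
    out.index_put_((n, b, u, v), features[n, b, src_u, src_v])
    return out
```

`out[n, b, u, v] = features[n, b, src_u, src_v]` would do the same write. This variant touches only the moving pixels, but it still pays for the `nonzero` synchronization, so the dense gather may be the faster choice when a large share of pixels has non-zero flow.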
Is there a way to get rid of the for loop?
Your help is very much appreciated!