I have been struggling with the following problem for a while now and would appreciate your help.
I have a for loop in my forward pass that slows down the training immensely, but I do not know how to get rid of it.
What I want to do is to resample features of an image based on optical flow. To do that, I have the following code:
    import torch

    # Get the indices where the flow is not zero
    flow_files_non_zero_indices = torch.nonzero(flow_files_shifted, as_tuple=True)
    # Clone: a plain assignment would alias face_images_features, so writes
    # below would leak into later reads
    face_images_features_resampled = face_images_features.clone()
    # Iterate over these indices
    for index in zip(
        flow_files_non_zero_indices[0],
        flow_files_non_zero_indices[1],
        flow_files_non_zero_indices[2],
        flow_files_non_zero_indices[3],
    ):
        # Get index
        n, b, u, v = index
        # Get delta_u and delta_v from flow
        delta_u = flow_files_shifted[n, b, u, v, 0].type(torch.int64)
        delta_v = flow_files_shifted[n, b, u, v, 1].type(torch.int64)
        # Index lookup
        try:
            face_images_features_resampled[n, b, u, v] = face_images_features[
                n, b, u + delta_u, v + delta_v
            ]
        except IndexError:
            # Only targets past the end raise; a negative target wraps around
            face_images_features_resampled[n, b, u, v] = face_images_features[
                n, b, u, v
            ]
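One possible way to drop the loop is to build the target coordinates for every pixel at once and gather them with a single advanced-indexing call. This is a minimal sketch, assuming `face_images_features` has shape `(N, B, H, W, ...)` and `flow_files_shifted` has shape `(N, B, H, W, 2)` (the shapes are not stated above); the function name `resample_dense` is hypothetical. Unlike the loop, it treats every target outside `[0, H) x [0, W)` as out of range, including negative ones, and keeps the original value there. Pixels with zero flow point at themselves, so they need no special case.

```python
import torch


def resample_dense(features: torch.Tensor, flow: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: pixel (u, v) reads from (u + du, v + dv).

    Assumed shapes (not given in the question):
      features: (N, B, H, W, *rest)
      flow:     (N, B, H, W, 2), last dim = (delta_u, delta_v)
    """
    N, B, H, W = flow.shape[:4]
    device = flow.device

    # Integer offsets; .to(torch.int64) truncates toward zero like .type(torch.int64)
    du = flow[..., 0].to(torch.int64)
    dv = flow[..., 1].to(torch.int64)

    # Index grids that broadcast against each other to (N, B, H, W)
    n = torch.arange(N, device=device).view(N, 1, 1, 1)
    b = torch.arange(B, device=device).view(1, B, 1, 1)
    u = torch.arange(H, device=device).view(1, 1, H, 1)
    v = torch.arange(W, device=device).view(1, 1, 1, W)

    # Target coordinates for every pixel, shape (N, B, H, W)
    tu = u + du
    tv = v + dv

    # Assumption: out-of-range targets fall back to the pixel itself
    in_bounds = (tu >= 0) & (tu < H) & (tv >= 0) & (tv < W)
    tu = torch.where(in_bounds, tu, u)
    tv = torch.where(in_bounds, tv, v)

    # One gather over all pixels; result has shape (N, B, H, W, *rest)
    return features[n, b, tu, tv]
```

Because every value is read from the untouched input, writes never feed back into later reads, and gradients reach `features` through the indexing op. As in the loop, no gradient reaches the flow because of the integer cast.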
I know of index_put_, but I do not see how I can use it here as I need to look up delta_u and delta_v based on the current index.
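The per-index lookup of `delta_u` and `delta_v` does not itself require a loop: indexing the flow with the whole index tensors returned by `torch.nonzero` yields one delta per selected pixel, and `index_put_` can then write all values in one call. A minimal sketch, under the same assumed shapes `(N, B, H, W, ...)` and `(N, B, H, W, 2)`; the name `resample_sparse` is hypothetical, and out-of-range targets (including negative ones) are assumed to keep the original value.

```python
import torch


def resample_sparse(features: torch.Tensor, flow: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: same lookup as the loop, only where flow != 0.

    Assumed shapes: features (N, B, H, W, *rest), flow (N, B, H, W, 2).
    """
    H, W = flow.shape[2], flow.shape[3]

    # A pixel needs work if either flow component is non-zero; reducing over
    # the last dim gives exactly four index tensors, one entry per pixel
    n, b, u, v = torch.nonzero(flow.ne(0).any(dim=-1), as_tuple=True)

    # Indexing with index tensors looks up all deltas at once, shape (K,)
    du = flow[n, b, u, v, 0].to(torch.int64)
    dv = flow[n, b, u, v, 1].to(torch.int64)

    tu = u + du
    tv = v + dv
    ok = (tu >= 0) & (tu < H) & (tv >= 0) & (tv < W)

    # Keep only in-bounds moves; the other pixels keep their original value
    n, b, u, v, tu, tv = n[ok], b[ok], u[ok], v[ok], tu[ok], tv[ok]

    out = features.clone()
    # Read from the untouched input, write into the copy in a single call
    out.index_put_((n, b, u, v), features[n, b, tu, tv])
    return out
```

Reducing over the last dimension before `torch.nonzero` matters: called on the 5-D flow tensor directly, it returns five index tensors and lists a pixel twice when both flow components are non-zero.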
Is there a way to get rid of the for loop?
Your help is very much appreciated!