I have used the RAFT model in PyTorch to calculate optical flow between two frames. Here is the code:

    noise_img = noise_img.to(device)
    clean_img = clean_img.to(device)
    return_index = noise_img.size(1) // 2
    aligned_frames = torch.zeros((noise_img.size(0), noise_img.size(1), noise_img.size(2), noise_img.size(3), noise_img.size(4)))
    aligned_frames[:, return_index, :, :, :] = noise_img[:, return_index, :, :, :]
    for idx in range(noise_img.size(1)):
        if not idx == return_index:
            curr_frame = noise_img[:, idx, :, :, :]
            ref_frame = noise_img[:, return_index, :, :, :]
            curr_transf, ref_transf = transforms(curr_frame, ref_frame)
            curr_flow = mc_model(curr_transf, ref_transf)[-1]  # Take the final flow prediction
            aligned_frames[:, idx, :, :, :] = align_frames(curr_transf, curr_flow)

In the code above, I pass two frames through mc_model (RAFT), which returns an optical flow map. In the final line, I try to warp the current frame so that it is aligned with the reference frame. Below is the function I use:

    def warp_flow(img, flow):
        flow_permute = torch.permute(flow, (0, 2, 3, 1))
        remapped = torch.nn.functional.grid_sample(img, flow_permute)
        return remapped

When saved as an image, remapped is not coherent: most of the outputs are all zeros, and some look like bright waves. I think I'm missing a step in how I use curr_flow, but I don't quite understand what.
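
One likely cause, offered as a sketch and not a confirmed fix: grid_sample expects a grid of absolute sampling positions normalized to [-1, 1], with shape (N, H, W, 2) and x first, whereas RAFT predicts per-pixel displacements in pixel units. Passing the raw flow as the grid therefore samples mostly outside the image, and the default zero padding would explain the black outputs. The version below adds the flow to a pixel coordinate grid and normalizes the result before sampling. The name warp_with_flow is hypothetical, and the code assumes flow has shape (N, 2, H, W) with channel 0 horizontal and channel 1 vertical, and the same dtype and device as img.

```python
import torch
import torch.nn.functional as F


def warp_with_flow(img: torch.Tensor, flow: torch.Tensor) -> torch.Tensor:
    """Backward-warp img (N, C, H, W) using a pixel-unit flow (N, 2, H, W).

    Assumption: flow[:, 0] is the horizontal (x) displacement and
    flow[:, 1] is the vertical (y) displacement, both in pixels.
    """
    _, _, h, w = img.shape

    # Pixel coordinates of every output location.
    ys, xs = torch.meshgrid(
        torch.arange(h, device=img.device, dtype=img.dtype),
        torch.arange(w, device=img.device, dtype=img.dtype),
        indexing="ij",
    )
    base = torch.stack((xs, ys), dim=0).unsqueeze(0)  # (1, 2, H, W), x first

    # Absolute sampling positions: where each output pixel reads from in img.
    coords = base + flow

    # Normalize to [-1, 1], matching align_corners=True below.
    x_norm = 2.0 * coords[:, 0] / max(w - 1, 1) - 1.0
    y_norm = 2.0 * coords[:, 1] / max(h - 1, 1) - 1.0
    grid = torch.stack((x_norm, y_norm), dim=-1)  # (N, H, W, 2)

    return F.grid_sample(
        img, grid, mode="bilinear", padding_mode="border", align_corners=True
    )
```

Two further points may be worth checking. Because this is backward warping, each output pixel at x reads from img at x + flow(x), so aligning the current frame to the reference frame would normally use the flow estimated from the reference to the current frame, i.e. mc_model(ref_transf, curr_transf), and not the other way round. Also, if transforms rescales the images (for example to [-1, 1]), the warped result has to be mapped back to the original range before it is saved as an image.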