I have used the RAFT model in PyTorch to calculate optical flow between two frames. Here is the code for that:
```python
import torch

noise_img = noise_img.to(device)
clean_img = clean_img.to(device)
# Use the middle frame of the sequence (dim 1) as the reference frame.
return_index = noise_img.size(1) // 2
aligned_frames = torch.zeros(noise_img.shape, device=device)
aligned_frames[:, return_index] = noise_img[:, return_index]
for idx in range(noise_img.size(1)):
    if idx != return_index:
        curr_frame = noise_img[:, idx]
        ref_frame = noise_img[:, return_index]
        curr_transf, ref_transf = transforms(curr_frame, ref_frame)
        curr_flow = mc_model(curr_transf, ref_transf)[-1]  # Take the final flow prediction
        aligned_frames[:, idx] = warp_flow(curr_transf, curr_flow)
```
In the code above, I pass each pair of frames through `mc_model` (RAFT), which returns an optical flow map. In the final line, I try to warp the current frame so that it is aligned with the reference frame. Below is the function I use:
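```python
import torch

def warp_flow(img, flow):
    # RAFT flow has shape (N, 2, H, W); grid_sample expects a grid of shape (N, H, W, 2)
    flow_permute = torch.permute(flow, (0, 2, 3, 1))
    remapped = torch.nn.functional.grid_sample(img, flow_permute)
    return remapped
```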
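A point worth checking here: `grid_sample` treats its `grid` argument as absolute sampling locations normalized to [-1, 1] (roughly left/top edge to right/bottom edge of the input), not as per-pixel displacements. RAFT's output is a displacement in pixels, so passing it in directly makes small flow values sample near the image center and values well outside [-1, 1] sample off the image, where the default `padding_mode="zeros"` returns 0. The sketch below illustrates both cases on a hypothetical 4x4 test image; the tensors are made up for illustration.

```python
import torch
import torch.nn.functional as F

# Hypothetical 1x1x4x4 test image with pixel values 1..16.
img = torch.arange(1.0, 17.0).reshape(1, 1, 4, 4)

# An all-zero "flow" used directly as the grid: every output pixel samples
# normalized location (0, 0), i.e. the image center, so the output is constant.
zero_grid = torch.zeros(1, 4, 4, 2)
out_zero = F.grid_sample(img, zero_grid, align_corners=False)

# A "flow" of 5 pixels used directly as the grid: 5.0 is far outside [-1, 1],
# so every sample falls off the image and zero padding returns 0.
far_grid = torch.full((1, 4, 4, 2), 5.0)
out_far = F.grid_sample(img, far_grid, align_corners=False)

print(out_zero)  # every entry is 8.5, the bilinear average of the 4 center pixels
print(out_far)   # every entry is 0
```

This behaviour is at least consistent with outputs that are mostly zero.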
Unfortunately, `remapped`, when saved as an image, is not a coherent image: most outputs are all zeros, and some look like bright waves. I'm missing a step in how I use `curr_flow`, but I don't quite understand what.
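One way to supply that step, sketched under assumptions: add the flow to a base grid of pixel coordinates to get absolute sampling positions, then rescale those positions to the [-1, 1] range that `grid_sample` expects. The helper name `warp_with_flow`, the use of `align_corners=True` (with the matching normalization), and `padding_mode="border"` are illustrative choices, not part of the original code. The sketch also assumes the flow is defined on the reference frame's pixels and points into the frame being warped (a ref-to-curr flow), because `grid_sample` pulls values from its input at the grid locations. If `mc_model` follows torchvision's RAFT convention of predicting the flow from its first image to its second, that would mean calling `mc_model(ref_transf, curr_transf)`; this is worth verifying for the RAFT implementation actually in use.

```python
import torch
import torch.nn.functional as F


def warp_with_flow(img: torch.Tensor, flow: torch.Tensor) -> torch.Tensor:
    """Backward-warp img (N, C, H, W) with a pixel-unit flow (N, 2, H, W).

    Output pixel (x, y) is sampled from img at (x + flow[:, 0], y + flow[:, 1]).
    Hypothetical helper for illustration.
    """
    n, _, h, w = img.shape
    # Base pixel grid: xs[i, j] = j (column), ys[i, j] = i (row).
    ys, xs = torch.meshgrid(
        torch.arange(h, dtype=img.dtype, device=img.device),
        torch.arange(w, dtype=img.dtype, device=img.device),
        indexing="ij",
    )
    # Absolute sampling positions in pixels: base grid plus displacement.
    x_src = xs.unsqueeze(0) + flow[:, 0]
    y_src = ys.unsqueeze(0) + flow[:, 1]
    # Normalize to [-1, 1]; with align_corners=True, -1 and 1 are the centers
    # of the first and last pixels along each axis.
    x_norm = 2.0 * x_src / max(w - 1, 1) - 1.0
    y_norm = 2.0 * y_src / max(h - 1, 1) - 1.0
    grid = torch.stack((x_norm, y_norm), dim=-1)  # (N, H, W, 2), last dim is (x, y)
    return F.grid_sample(img, grid, mode="bilinear",
                         padding_mode="border", align_corners=True)


# Synthetic check: curr is ref shifted 2 pixels to the right.
torch.manual_seed(0)
ref = torch.rand(1, 3, 8, 8)
curr = torch.zeros_like(ref)
curr[..., 2:] = ref[..., :-2]

# Ref pixel (x, y) appears at (x + 2, y) in curr, so the ref-to-curr flow is (+2, 0).
flow_ref_to_curr = torch.zeros(1, 2, 8, 8)
flow_ref_to_curr[:, 0] = 2.0

aligned = warp_with_flow(curr, flow_ref_to_curr)
# Away from the right border, the warped frame matches the reference.
assert torch.allclose(aligned[..., :-2], ref[..., :-2], atol=1e-5)
```

The check at the end shifts a random image 2 pixels to the right and confirms that warping it back with a constant flow of +2 in x recovers the reference, except in the last two columns, where the samples fall outside the image.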
Thank you.