How do I use grid_sample to translate a frame's pixels using a flow field?

I have used the RAFT model in PyTorch to calculate the optical flow between two frames. Here is the code for that:

    noise_img = noise_img.to(device)
    clean_img = clean_img.to(device)
    # Index of the reference (centre) frame
    return_index = noise_img.size(1) // 2
    aligned_frames = torch.zeros(noise_img.shape)
    aligned_frames[:, return_index, :, :, :] = noise_img[:, return_index, :, :, :]

    # Warp every other frame so that it lines up with the reference frame
    for idx in range(noise_img.size(1)):
        if idx != return_index:
            curr_frame = noise_img[:, idx, :, :, :]
            ref_frame = noise_img[:, return_index, :, :, :]
            curr_transf, ref_transf = transforms(curr_frame, ref_frame)
            curr_flow = mc_model(curr_transf, ref_transf)[-1]  # Take the final flow prediction
            aligned_frames[:, idx, :, :, :] = warp_flow(curr_transf, curr_flow)

In the above, I am passing two frames through mc_model (RAFT) to get an optical flow map. In the final line, I am trying to warp the current frame so that it is aligned with the reference frame. Below is the warp_flow function I use:

    def warp_flow(img, flow):
        # Reorder the flow from (N, 2, H, W) to (N, H, W, 2)
        flow_permute = torch.permute(flow, (0, 2, 3, 1))
        # Pass the permuted flow directly to grid_sample as the sampling grid
        remapped = torch.nn.functional.grid_sample(img, flow_permute)
        return remapped
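
For reference, the shapes going into this function are, as far as I understand, an (N, C, H, W) image batch and an (N, 2, H, W) flow tensor from RAFT, where channel 0 holds the horizontal and channel 1 the vertical displacement in pixels (the channel count of 3 below is just my assumption about my frames):

    # Shapes in my setup (N = batch size, H/W = frame height/width):
    #   curr_transf: (N, 3, H, W)  -- the transformed current frame
    #   curr_flow:   (N, 2, H, W)  -- RAFT's final flow prediction, (dx, dy) per pixel
    aligned = warp_flow(curr_transf, curr_flow)  # aligned: (N, 3, H, W)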

Unfortunately, remapped, when saved as an image, is not a coherent result: most outputs are almost entirely zero, with a few looking like bright waves. I am clearly missing a step in how curr_flow should be fed to grid_sample, but I don't quite understand what.
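
My current suspicion, from reading the grid_sample documentation, is that the grid argument should contain absolute sampling coordinates normalized to [-1, 1] rather than raw pixel displacements, so the flow would first have to be added to a base pixel grid and then normalized. Below is a sketch of what I think that step looks like; the function name warp_with_flow and the align_corners=True setting are my own guesses rather than anything taken from the RAFT code:

    import torch

    def warp_with_flow(img, flow):
        # img:  (N, C, H, W) frame to warp
        # flow: (N, 2, H, W) per-pixel displacements, channel 0 = dx, channel 1 = dy
        N, C, H, W = img.shape

        # Base grid of absolute pixel coordinates
        ys, xs = torch.meshgrid(
            torch.arange(H, device=img.device, dtype=img.dtype),
            torch.arange(W, device=img.device, dtype=img.dtype),
            indexing="ij",
        )
        base = torch.stack((xs, ys), dim=0).unsqueeze(0)  # (1, 2, H, W), (x, y) order

        # Absolute coordinates to sample from, normalized to [-1, 1]
        coords = base + flow
        x_norm = 2.0 * coords[:, 0] / (W - 1) - 1.0
        y_norm = 2.0 * coords[:, 1] / (H - 1) - 1.0
        grid = torch.stack((x_norm, y_norm), dim=-1)  # (N, H, W, 2)

        return torch.nn.functional.grid_sample(img, grid, align_corners=True)

Is converting the flow into a normalized grid like this the missing step, or is something else going wrong in my pipeline?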

Thank you.