cv2.remap in PyTorch?

Is there functionality close to cv2’s remap? PyTorch’s grid_sample behaves differently for the things I’ve tried with it.

Hi,

From a quick look at cv2’s docs, the two should be doing the same thing, no? What kind of issues are you having with grid_sample?

I’ve tried optical-flow warping to generate the next frame from the previous one, but whether I use the .flo file associated with the test images or a flow computed with cv2’s Farneback dense optical flow, the grid_sample-warped frame has visible differences.

(from a previous question of mine for the details)

If I use cv2’s remap, I get a perfect frame, so I thought their functionality might be different, but the documentation reads as if they’re basically the same.
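For concreteness, the warp I’m describing looks roughly like this with cv2.remap; a minimal sketch (warp_with_flow is just an illustrative helper, and the flow sign depends on which direction the flow was estimated in):

import cv2
import numpy as np

def warp_with_flow(prev_frame, flow):
    # prev_frame: (H, W, C) image; flow: (H, W, 2) with (dx, dy) per pixel.
    H, W = flow.shape[:2]
    xs, ys = np.meshgrid(np.arange(W), np.arange(H))
    # Subtract the flow from the pixel grid to find where to read in prev_frame.
    map_x = (xs - flow[..., 0]).astype(np.float32)
    map_y = (ys - flow[..., 1]).astype(np.float32)
    return cv2.remap(prev_frame, map_x, map_y, interpolation=cv2.INTER_LINEAR)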

One notable difference, maybe, is that grid_sample does not take pixel coordinates as input, but values in [-1, 1] that tell it where to read in the input image.
Are you normalizing your grid values properly to match this?
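For example, converting absolute pixel coordinates (the convention cv2.remap uses) into that range looks roughly like this; a minimal sketch, assuming the align_corners=True convention where -1 and +1 land on the corner pixel centers:

import torch

H, W = 240, 320
# Hypothetical maps in absolute pixel coordinates, as cv2.remap would take them.
map_x = torch.full((H, W), 100.0)  # read from column 100 everywhere
map_y = torch.full((H, W), 50.0)   # read from row 50 everywhere
gx = 2 * map_x / (W - 1) - 1       # pixel 0 -> -1, pixel W-1 -> +1
gy = 2 * map_y / (H - 1) - 1
grid = torch.stack((gx, gy), dim=-1).unsqueeze(0)  # (1, H, W, 2), x before y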

Yep, I took care of that. Since the flow was just subtracted from the pixel grid, the result was easy to normalize to that range, so the differences between the two functions shouldn’t come from that step, implementation-wise (at least there).

Sounds good.
And is remap also doing bilinear interpolation? Maybe that can lead to surprising results?

I’ve tested both nearest and bilinear with it, but the differences in the results are negligible for both.


I also have the same issue. Does anyone have a suggestion for this? Please reply.

Good news: I believe you can now do this in PyTorch 1.11.0


import torch

def remap_values(remapping, x):
    # Look up each element of x in the sorted domain remapping[0] and replace
    # it with the corresponding entry of remapping[1].
    # Note: this remaps pixel *values* (a lookup table), not pixel *locations*
    # the way cv2.remap with coordinate maps does.
    index = torch.bucketize(x.ravel(), remapping[0])
    return remapping[1][index].reshape(x.shape)


# Map each intensity 0..255 to a random permutation of 0..255.
remapping = torch.arange(0, 256).cuda(), torch.randperm(256).cuda()
images_batch = torch.randint(0, 256, (16, 224, 224, 3)).cuda()
remapped_batch = remap_values(remapping, images_batch)
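If I read this right, since remapping[0] is the identity domain 0..255 and the inputs are exact integers, the bucketized lookup reduces to plain fancy indexing, which doubles as a sanity check:

# Sanity check: bucketize against an identity domain is just direct indexing.
assert torch.equal(remapped_batch, remapping[1][images_batch])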

@thehappyidiot Any updates on this?

I also made the same observation recently. Basically, the output from grid_sample() looks different from what cv2.remap() produces, in that the former appears “scaled” compared to the latter. The best way I can describe this scaling effect is what this question mentions. Depending on the inputs, cv2.remap() leaves the invalid areas of the output unpopulated, whereas grid_sample() appears to have scaled the image so that these areas do not exist.

Not sure how to replicate the behavior of cv2.remap() using grid_sample(). Maybe this difference is intended and there’s no way around it…?
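For what it’s worth, one likely culprit is the align_corners flag: the usual x/((W-1)/2)-1 normalization corresponds to align_corners=True, but grid_sample has defaulted to align_corners=False since PyTorch 1.3, and the mismatch produces exactly this half-pixel “scaled” look. Here is a rough sketch of reproducing cv2.remap with grid_sample under that assumption (remap_like_cv2 is a hypothetical helper, not a library function):

import cv2
import numpy as np
import torch
import torch.nn.functional as F

def remap_like_cv2(img, map_x, map_y):
    # img: (H, W, C) array; map_x, map_y: (H, W) float32 absolute pixel coords,
    # same convention as cv2.remap(img, map_x, map_y, cv2.INTER_LINEAR).
    H, W = map_x.shape
    t = torch.from_numpy(img).float().permute(2, 0, 1).unsqueeze(0)  # (1, C, H, W)
    # Normalize pixel coordinates to [-1, 1]; this formula matches align_corners=True.
    gx = torch.from_numpy(map_x) / ((W - 1) / 2) - 1
    gy = torch.from_numpy(map_y) / ((H - 1) / 2) - 1
    grid = torch.stack((gx, gy), dim=-1).unsqueeze(0)  # (1, H, W, 2)
    out = F.grid_sample(t, grid, mode='bilinear',
                        padding_mode='zeros', align_corners=True)
    return out.squeeze(0).permute(1, 2, 0).numpy()

# Smoke test against OpenCV with a random map:
img = np.random.randint(0, 256, (32, 40, 3), dtype=np.uint8)
map_x = np.random.uniform(0, 39, (32, 40)).astype(np.float32)
map_y = np.random.uniform(0, 31, (32, 40)).astype(np.float32)
ours = remap_like_cv2(img, map_x, map_y)
ref = cv2.remap(img.astype(np.float32), map_x, map_y, cv2.INTER_LINEAR)
print(np.abs(ours - ref).max())  # small, up to cv2's fixed-point interpolation rounding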

I have also run into this translation problem. Here is my code for cv2.remap() and torch.nn.functional.grid_sample(); it is tailored to my task.

My task is to project ref_img and ref_depth from a reference view into another (source) view.

The code in NumPy and cv2 style:

import cv2
import numpy as np


def reproject_with_depth(img_ref, depth_ref, intrinsics_ref, extrinsics_ref, intrinsics_src, extrinsics_src):
    width, height = depth_ref.shape[1], depth_ref.shape[0]

    # Pixel grid of the reference view
    x_ref, y_ref = np.meshgrid(np.arange(0, width), np.arange(0, height))
    x_ref, y_ref = x_ref.reshape([-1]), y_ref.reshape([-1])

    # Unproject reference pixels into 3D space using the reference depth
    xyz_ref = np.matmul(np.linalg.inv(intrinsics_src),
                        np.vstack((x_ref, y_ref, np.ones_like(x_ref))) * depth_ref.reshape([-1]))

    # Transform the 3D points into the source camera frame
    xyz_src = np.matmul(np.matmul(extrinsics_ref, np.linalg.inv(extrinsics_src)),
                        np.vstack((xyz_ref, np.ones_like(x_ref))))[:3]

    # Project into the source image plane
    K_xyz_src = np.matmul(intrinsics_ref, xyz_src)
    xy_src = K_xyz_src[:2] / K_xyz_src[2:3]

    x_src = xy_src[0].reshape([height, width]).astype(np.float32)
    y_src = xy_src[1].reshape([height, width]).astype(np.float32)

    # Sample the reference depth/image at the projected coordinates
    sampled_depth_src = cv2.remap(depth_ref, x_src, y_src, interpolation=cv2.INTER_LINEAR)
    sampled_img_src = cv2.remap(img_ref, x_src, y_src, interpolation=cv2.INTER_LINEAR)

    return sampled_depth_src, sampled_img_src

And this is the translation into torch style:

import torch
import torch.nn.functional as F


def reproject_with_depth(img_ref, depth_ref, intrinsics_ref, extrinsics_ref, intrinsics_src, extrinsics_src):
    B, width, height = depth_ref.shape[0], depth_ref.shape[2], depth_ref.shape[1]

    # Pixel grid of the reference view (on the same device as the input)
    y_ref, x_ref = torch.meshgrid(torch.arange(0, height, dtype=torch.float32, device=depth_ref.device),
                                  torch.arange(0, width, dtype=torch.float32, device=depth_ref.device),
                                  indexing="ij")
    x_ref, y_ref = x_ref.contiguous().reshape([-1]), y_ref.contiguous().reshape([-1])

    # Unproject reference pixels into 3D space using the reference depth
    xyz_ref = torch.matmul(torch.inverse(intrinsics_src),
                           torch.stack((x_ref, y_ref, torch.ones_like(x_ref))).unsqueeze(0).repeat(B, 1, 1)
                           * depth_ref.reshape([B, 1, -1]))

    # Transform the 3D points into the source camera frame
    xyz_src = torch.matmul(torch.matmul(extrinsics_ref, torch.inverse(extrinsics_src)),
                           torch.cat([xyz_ref, torch.ones_like(x_ref).unsqueeze(0).repeat(B, 1, 1)], dim=1))[:, :3]

    # Project into the source image plane
    K_xyz_src = torch.matmul(intrinsics_ref, xyz_src)
    xy_src = K_xyz_src[:, :2] / K_xyz_src[:, 2:3]

    x_src = xy_src[:, 0].reshape([B, height, width]).float()
    y_src = xy_src[:, 1].reshape([B, height, width]).float()

    # Normalize pixel coordinates to [-1, 1]; this formula matches align_corners=True,
    # so pass align_corners=True explicitly (the default changed to False in PyTorch 1.3).
    grid = torch.stack((x_src / ((width - 1) / 2) - 1, y_src / ((height - 1) / 2) - 1), dim=3)  # (B, H, W, 2)
    sampled_depth_src = F.grid_sample(depth_ref.unsqueeze(1), grid, mode='bilinear',
                                      padding_mode='zeros', align_corners=True).squeeze(1)
    sampled_img_src = F.grid_sample(img_ref, grid, mode='bilinear',
                                    padding_mode='zeros', align_corners=True)

    return sampled_depth_src, sampled_img_src

The essential translation principles to obey, I think:

  • You should account for the batch dimension B in the torch version, and likewise for the device parameter.
  • The sampling positions in torch must be mapped into [-1, 1]; hence you should use /((width-1)/2)-1 for the x-axis and /((height-1)/2)-1 for the y-axis, which matches align_corners=True (see the sketch below).
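To sanity-check the shapes of the torch version above, a minimal hypothetical usage sketch (assuming depth of shape (B, H, W), image (B, C, H, W), intrinsics (B, 3, 3), and extrinsics (B, 4, 4)):

import torch

B, C, H, W = 2, 3, 64, 80
depth_ref = torch.rand(B, H, W) + 1.0          # positive depths
img_ref = torch.rand(B, C, H, W)
K = torch.eye(3).unsqueeze(0).repeat(B, 1, 1)  # toy intrinsics
K[:, 0, 0] = K[:, 1, 1] = 50.0                 # focal length
K[:, 0, 2], K[:, 1, 2] = W / 2, H / 2          # principal point
E_ref = torch.eye(4).unsqueeze(0).repeat(B, 1, 1)
E_src = E_ref.clone()
E_src[:, 0, 3] = 0.1                           # small x translation

depth_src, img_src = reproject_with_depth(img_ref, depth_ref, K, E_ref, K, E_src)
print(depth_src.shape, img_src.shape)          # (B, H, W) and (B, C, H, W)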

I hope it will be helpful to you. 🙂


Remap is implemented in Kornia

https://kornia.readthedocs.io/en/latest/geometry.transform.html#kornia.geometry.transform.remap

You also have all the utils needed for the depth stuff
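For reference, a minimal sketch of that call based on the linked docs (please double-check the signature against your installed Kornia version):

import torch
import kornia

B, C, H, W = 1, 3, 48, 64
image = torch.rand(B, C, H, W)
# Identity maps in absolute pixel coordinates, same convention as cv2.remap's map1/map2.
ys, xs = torch.meshgrid(torch.arange(H, dtype=torch.float32),
                        torch.arange(W, dtype=torch.float32), indexing="ij")
map_x = xs.unsqueeze(0)  # (B, H, W)
map_y = ys.unsqueeze(0)
out = kornia.geometry.transform.remap(image, map_x, map_y, align_corners=True)
print(torch.allclose(out, image, atol=1e-5))  # identity remap reproduces the input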

Thank you very much! It’s very helpful.