Questions about the torchvision RAFT API

Hi, I'm confused about the API of torchvision's RAFT implementation (thanks for the convenient method, btw) and have a couple of questions about it.

My first and biggest question is about the reference image among img1 and img2 in this example: which one is the reference frame on which the predicted flow is defined? (Note that your definition of "reference" might differ from mine.)
In other words, could you document the actual direction of the flow (img1 → img2 or img2 → img1)? I guess it is the latter.

Roughly speaking, we get flow = predict_raft(img1, img2). As a sanity check, I warped one of the input images with the flow and checked whether the warped image looks similar to the other input image. My warp function is attached at the end of this post (in case my warp implementation is wrong).

When I check both directions, comparing img1 with warped_img = warp(img2, flow) looks more reasonable than comparing img2 with warped_img = warp(img1, flow).
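The direction question can also be checked without RAFT by using a synthetic pair whose flow is known. In the sketch below, `backward_warp` is an illustrative re-implementation of the same idea as the `warp` function at the end of this post, not a library API. img2 is img1 shifted 3 pixels to the right, so the flow from img1 to img2 is u = +3, v = 0 everywhere. `F.grid_sample` performs backward warping: each output pixel (x, y) pulls its value from the input at (x + u, y + v). The output therefore lives on the pixel grid the flow is defined on. As a result, a flow defined on img1 and pointing toward img2 is exactly the one for which warp(img2, flow) resembles img1.

```python
import torch
import torch.nn.functional as F


def backward_warp(img: torch.Tensor, flow: torch.Tensor) -> torch.Tensor:
    """Illustrative helper: sample img at (x + u, y + v) for every output pixel (x, y).

    img: [B, C, H, W]; flow: [B, 2, H, W] with channel 0 = u (x), channel 1 = v (y).
    """
    B, _, H, W = img.shape
    ys, xs = torch.meshgrid(
        torch.arange(H, dtype=img.dtype, device=img.device),
        torch.arange(W, dtype=img.dtype, device=img.device),
        indexing="ij",
    )
    x = xs + flow[:, 0]  # [B, H, W] sampling x-positions
    y = ys + flow[:, 1]  # [B, H, W] sampling y-positions
    # Normalize to [-1, 1]; with align_corners=True, -1 and 1 are the centers of the corner pixels.
    grid = torch.stack((2.0 * x / max(W - 1, 1) - 1.0, 2.0 * y / max(H - 1, 1) - 1.0), dim=-1)
    return F.grid_sample(img, grid, mode="bilinear", padding_mode="zeros", align_corners=True)


# Synthetic pair: img2 is img1 shifted 3 px to the right, i.e. img2[..., x + 3] == img1[..., x].
torch.manual_seed(0)
img1 = torch.rand(1, 3, 32, 32)
img2 = torch.zeros_like(img1)
img2[..., 3:] = img1[..., :-3]

# Flow from img1 to img2: every pixel of img1 moves by u = +3, v = 0.
flow_1to2 = torch.zeros(1, 2, 32, 32)
flow_1to2[:, 0] = 3.0

# Backward-warping img2 with the img1 -> img2 flow reconstructs img1
# (except in the last 3 columns, where x + 3 falls outside img2 and zero padding is used).
recon = backward_warp(img2, flow_1to2)
err = (recon[..., :-3] - img1[..., :-3]).abs().max().item()

# Warping img1 with the same flow does NOT reproduce img2.
wrong = backward_warp(img1, flow_1to2)
wrong_err = (wrong[..., 3:-3] - img2[..., 3:-3]).abs().max().item()
print(err, wrong_err)  # err is ~0 (float rounding only); wrong_err is large
```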

I found that other people have been confused by this too (e.g., "How to use it to calculate the optical flow", Issue #158 in princeton-vl/RAFT on GitHub; please use your browser's translation).

A related question: could you provide the color wheel (i.e., the legend for the optical flow colors) used by torchvision.utils.flow_to_image?

```python
import torch
import torch.nn.functional as F


def warp(x, flo):
    """Warp an image tensor according to the optical flow.

    x: [B, C, H, W] (image)
    flo: [B, 2, H, W] (flow; channel 0 = x displacement, channel 1 = y displacement)
    """
    B, C, H, W = x.size()
    # mesh grid of pixel coordinates
    xx = torch.arange(0, W).view(1, -1).repeat(H, 1)
    yy = torch.arange(0, H).view(-1, 1).repeat(1, W)
    xx = xx.view(1, 1, H, W).repeat(B, 1, 1, 1)
    yy = yy.view(1, 1, H, W).repeat(B, 1, 1, 1)
    grid = torch.cat((xx, yy), 1).float()
    if x.is_cuda:
        grid = grid.cuda()
    # sampling positions: each output pixel reads x at (pixel + flow)
    vgrid = grid + flo
    # scale grid to [-1, 1] for grid_sample
    vgrid[:, 0, :, :] = 2.0 * vgrid[:, 0, :, :].clone() / max(W - 1, 1) - 1.0
    vgrid[:, 1, :, :] = 2.0 * vgrid[:, 1, :, :].clone() / max(H - 1, 1) - 1.0
    # grid_sample expects the grid as [B, H, W, 2]
    vgrid = vgrid.permute(0, 2, 3, 1)
    output = F.grid_sample(x, vgrid, align_corners=True)
    return output
```