Affine transformation matrix parameters conversion

Hi all,

I want to rotate an image about a specific point. First I create the transformation matrices for moving the center point to the origin, rotating, and then moving back to the original point; then I apply the transform using the affine_grid and grid_sample functions. But the resulting image is not what it should be. I tested these parameters by applying them to the image with scipy, and there it works.

import torch
import torch.nn.functional as F

# translate the rotation center to the origin
# (note: for an N x C x H x W tensor, shape[2] is the height and shape[3] the width)
mat_move = torch.eye(3)
mat_move[0, 2] = -x_center * 2 / my_image.shape[2]
mat_move[1, 2] = -y_center * 2 / my_image.shape[3]

# rotate about the origin
mat_rotate = torch.eye(3)
mat_rotate[0, 0] = cos_theta[0][0]
mat_rotate[0, 1] = -sin_theta[0][0]
mat_rotate[1, 0] = sin_theta[0][0]
mat_rotate[1, 1] = cos_theta[0][0]

# translate back to the rotation center
mat_move_back = torch.eye(3)
mat_move_back[0, 2] = x_center * 2 / my_image.shape[2]
mat_move_back[1, 2] = y_center * 2 / my_image.shape[3]

# compose: move -> rotate -> move back
rigid_transform = torch.mm(mat_move_back, torch.mm(mat_rotate, mat_move))

# top two rows of the 3x3 matrix, batched for affine_grid
M = torch.zeros([1, 2, 3]).cuda()
M[0, :, :] = rigid_transform[:2, :]

grid = F.affine_grid(M, vertebrae.size())
vertebrae = F.grid_sample(vertebrae.float(), grid)

How are you comparing the results?
It seems scipy.ndimage.affine_transform uses pixel values for the translation part, while F.affine_grid seems to want values in the range [-1, 1] (which you already provided).
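For instance, a minimal sketch with made-up numbers showing the rescaling:

W = 100           # image width in pixels (assumed)
t_x_pixels = 25   # desired shift along x in pixels (assumed)
t_x_normalized = t_x_pixels * 2.0 / W   # -> 0.5 in affine_grid's [-1, 1] range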

This code tries to rotate and translate a line:

import numpy as np
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

x = torch.eye(10).view(1, 1, 10, 10)
theta = torch.zeros(1, 2, 3)
angle = np.pi / 2.
theta[:, :, :2] = torch.tensor([[np.cos(angle), -1.0 * np.sin(angle)],
                                [np.sin(angle), np.cos(angle)]])
theta[:, :, 2] = 0.5

grid = F.affine_grid(theta, x.size())
x_trans = F.grid_sample(x, grid)

plt.imshow(x.squeeze().numpy())
plt.figure()
plt.imshow(x_trans.squeeze().numpy())
plt.show()

Based on the values of grid, the operation should work.
The visualizations, however, look a bit strange, but this might be due to the interpolation.
Maybe someone knows this better.
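One quick sanity check (my addition, not from the original post) is to inspect a few grid values directly; affine_grid returns a tensor of shape (N, H, W, 2) holding the normalized (x, y) sampling location for every output pixel:

print(grid.shape)       # torch.Size([1, 10, 10, 2])
print(grid[0, 0, 0])    # where the top-left output pixel samples the input
print(grid[0, -1, -1])  # where the bottom-right output pixel samples the input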


For example, the following code compares the results of the same operation using PyTorch and scipy. The results are not the same.

import numpy as np
import scipy.ndimage
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

x = torch.eye(10).view(1, 1, 10, 10)
theta = torch.zeros(1, 2, 3)
angle = np.pi / 2.
theta[:, :, :2] = torch.tensor([[np.cos(angle), -1.0 * np.sin(angle)],
                                [np.sin(angle), np.cos(angle)]])
theta[:, :, 2] = 0.5

grid = F.affine_grid(theta, x.size())
x_trans = F.grid_sample(x, grid)

plt.imshow(x.squeeze().numpy())
plt.figure()
plt.imshow(x_trans.squeeze().numpy())

# the same rotation with scipy, but with the translation in pixels
x = np.eye(10)
angle = np.pi / 2.
theta = np.float32([[np.cos(angle), -1.0 * np.sin(angle), 5],
                    [np.sin(angle), np.cos(angle), 5]])  # or 7.5 instead of 5

x_trans = scipy.ndimage.affine_transform(x, theta, order=1)

plt.figure()
plt.imshow(x)
plt.figure()
plt.imshow(x_trans)
plt.show()

Hi, I also ran into this problem. I just want to know how to convert an affine transformation matrix as used by scipy / skimage.transform / OpenCV into the right theta argument for torch.nn.functional.affine_grid(theta, size).

Now suppose we want to apply an affine transformation to an image with shape=(H, W, 3), where the transformation matrix in pixel coordinates is [[a, b, c], [d, e, f]].

What's the right theta to use in torch.nn.functional.affine_grid(theta, size)?


@amirid
Hi, I have solved it. The code is below.

import cv2
import numpy as np
import skimage.transform as trans
import torch
import torch.nn.functional as F

def convert_image_np(inp):
    """Convert a Tensor to a numpy image."""
    inp = inp.numpy().transpose((1, 2, 0))
    inp = (inp * 255).astype(np.uint8)
    return inp

def param2theta(param, w, h):
    param = np.linalg.inv(param)
    theta = np.zeros([2, 3])
    theta[0, 0] = param[0, 0]
    theta[0, 1] = param[0, 1] * h / w
    theta[0, 2] = param[0, 2] * 2 / w + param[0, 0] + param[0, 1] - 1
    theta[1, 0] = param[1, 0] * w / h
    theta[1, 1] = param[1, 1]
    theta[1, 2] = param[1, 2] * 2 / h + param[1, 0] + param[1, 1] - 1
    return theta

# src, dst (corresponding point sets), image (H x W x 3 uint8) and w, h
# are assumed to be defined by the user
tr = trans.estimate_transform('affine', src=src, dst=dst)
M = tr.params[0:2, :]
img = cv2.warpAffine(image, M, (w, h))

theta = param2theta(tr.params, w, h)
# convert the HWC uint8 image to a 1 x 3 x H x W float tensor in [0, 1]
image = torch.from_numpy(image.transpose(2, 0, 1)).float().div(255).unsqueeze(0).cuda()
theta = torch.from_numpy(theta).float().unsqueeze(0).cuda()
grid = F.affine_grid(theta, image.size())
img_ = F.grid_sample(image, grid)
img_ = convert_image_np(img_.data.cpu().squeeze(0))

The visualizations of img and img_ should be the same. Hope it can help you!
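As a rough numerical check (my addition; the borders can differ slightly due to interpolation):

import numpy as np
# mean absolute difference between the two uint8 warps
print(np.abs(img.astype(np.int32) - img_.astype(np.int32)).mean())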

@xiang Hi, the iii I get from cv2.warpAffine is not always the same as the x I get from grid_sample. Any idea? Thanks.

    def forward(self, feature_map, boxes, mapping):
        '''

        :param feature_map:  N * 128 * 128 * 32
        :param boxes: M * 8
        :param mapping: mapping for image
        :return: N * H * W * C
        '''

        max_width = 0
        boxes_width = []
        cropped_images = []
        matrixes = []
        images = []

        for img_index, box in zip(mapping, boxes):
            feature = feature_map[img_index]  # C * H * W
            images.append(feature)

            x1, y1, x2, y2, x3, y3, x4, y4 = box / 4  # 512 -> 128

            # show_box(feature, box / 4, 'ffffff', isFeaturemap=True)

            rotated_rect = cv2.minAreaRect(np.array([[x1, y1], [x2, y2], [x3, y3], [x4, y4]]))
            box_w, box_h = rotated_rect[1][0], rotated_rect[1][1]

            width = feature.shape[2]
            height = feature.shape[1]

            if box_w <= box_h:
                box_w, box_h = box_h, box_w

            mapped_x1, mapped_y1 = (0, 0)
            mapped_x4, mapped_y4 = (0, self.height)

            width_box = math.ceil(self.height * box_w / box_h)
            max_width = width_box if width_box > max_width else max_width

            mapped_x2, mapped_y2 = (width_box, 0)

            # affine_matrix = cv2.getAffineTransform(np.float32([(x1, y1), (x2, y2), (x4, y4)]), np.float32([
            #     (mapped_x1, mapped_y1), (mapped_x2, mapped_y2), (mapped_x4, mapped_y4)
            # ]))

            affine_matrix = trans.estimate_transform('affine', np.float32([(x1, y1), (x2, y2), (x4, y4)]), np.float32([
                (mapped_x1, mapped_y1), (mapped_x2, mapped_y2), (mapped_x4, mapped_y4)
            ]))

            affine_matrix = affine_matrix.params[0:2, :]
            iii = cv2.warpAffine(feature.permute(1,2,0).cpu().numpy().astype(np.uint8),
                                 affine_matrix, (width, height))

            cv2.imshow('img', iii)
            cv2.waitKey()

            affine_matrix = self.param2theta(affine_matrix, width, height)

            grid = torch.nn.functional.affine_grid(torch.tensor(affine_matrix[np.newaxis]), feature[np.newaxis].size())
            x = torch.nn.functional.grid_sample(feature[np.newaxis], grid)
            x = x[0].permute(1, 2, 0).detach().cpu().numpy()
            x = x.astype(np.uint8)

            cv2.imshow('img', x)
            cv2.waitKey()

            matrixes.append(torch.tensor(affine_matrix, device=feature.device))
            boxes_width.append(width_box)

        matrixes = torch.stack(matrixes)
        images = torch.stack(images)
        grid = nn.functional.affine_grid(matrixes, images.size())
        feature_rotated = nn.functional.grid_sample(images, grid)


        channels = feature_rotated.shape[1]
        cropped_images_padded = torch.zeros((len(feature_rotated), channels, self.height, max_width),
                                            dtype=feature_rotated.dtype,
                                            device=feature_rotated.device)

Hi, I cannot upload the code right now. QQ: 6553947

In

def param2theta(param, w, h):
    param = np.linalg.inv(param)

How do you invert a non-square matrix?

Just append the row [0, 0, 1] at the bottom to make it 3x3.
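For example, a one-line sketch:

import numpy as np
M_aug = np.vstack([M, [0.0, 0.0, 1.0]])  # 2x3 -> 3x3, now invertible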

This version is wrong:

def param2theta(param, w, h):
    param = np.linalg.inv(param)
    theta = np.zeros([2, 3])
    theta[0, 0] = param[0, 0]
    theta[0, 1] = param[0, 1] * h / w
    theta[0, 2] = param[0, 2] * 2 / w + param[0, 0] + param[0, 1] - 1
    theta[1, 0] = param[1, 0] * w / h
    theta[1, 1] = param[1, 1]
    theta[1, 2] = param[1, 2] * 2 / h + param[1, 0] + param[1, 1] - 1
    return theta

Below is the correct version:

def param2theta(param, w, h):
    param = np.linalg.inv(param)
    theta = np.zeros([2, 3])
    theta[0, 0] = param[0, 0]
    theta[0, 1] = param[0, 1] * h / w
    theta[0, 2] = param[0, 2] * 2 / w + theta[0, 0] + theta[0, 1] - 1
    theta[1, 0] = param[1, 0] * w / h
    theta[1, 1] = param[1, 1]
    theta[1, 2] = param[1, 2] * 2 / h + theta[1, 0] + theta[1, 1] - 1
    return theta
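As a quick numerical check (my own sketch, not from the original posts), this corrected param2theta agrees with the closed form theta = N * inv(param) * inv(N) discussed further below:

import numpy as np

w, h = 64, 48
param = np.array([[0.9, -0.2, 5.0],
                  [0.2,  0.9, 3.0],
                  [0.0,  0.0, 1.0]])  # an arbitrary 3x3 affine in pixel coords

# N maps pixel coordinates to affine_grid's normalized [-1, 1] range
N = np.array([[2.0 / w, 0.0, -1.0],
              [0.0, 2.0 / h, -1.0],
              [0.0, 0.0, 1.0]])

theta_closed_form = (N @ np.linalg.inv(param) @ np.linalg.inv(N))[:2, :]
assert np.allclose(param2theta(param, w, h), theta_closed_form)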

@jiangxiluning
Yes, I have already solved this problem with the same method, but I don't fully understand why it works. Could you share your explanation?

The param2theta function corresponds to the equation

theta = N * M^-1 * N^-1, with N = {{2/w, 0, -1}, {0, 2/h, -1}, {0, 0, 1}},

where N rescales the width and height to 2 and shifts by -1, putting the coordinates into the range [-1, 1] for each dimension.

See Wolfram|Alpha: {{2/w, 0, -1}, {0, 2/h, -1}, {0, 0, 1}} * {{a, b, c}, {d, e, f}, {0, 0, 1}} * {{2/w, 0, -1}, {0, 2/h, -1}, {0, 0, 1}}^-1


Demo code can be found at https://github.com/wuneng/WarpAffine2GridSample.


Hi @Budel, could you explain how you got the equation? :joy:

To mimic cv2.warpAffine in PyTorch you can use kornia's version: https://kornia.readthedocs.io/en/latest/geometry.transform.html#kornia.geometry.transform.warp_affine
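A minimal usage sketch (my own, with made-up sizes; warp_affine expects the same pixel-space matrix as OpenCV, batched):

import torch
import kornia

img = torch.rand(1, 3, 64, 64)          # B x C x H x W
M = torch.tensor([[[1.0, 0.0, 8.0],
                   [0.0, 1.0, 4.0]]])   # B x 2 x 3, translation in pixels
out = kornia.geometry.transform.warp_affine(img, M, dsize=(64, 64))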


@jiangxiluning Actually, it is not difficult to derive, as long as you know that PyTorch uses normalized coordinates (which you can learn from other discussions in this forum). The derivation below shows theta to param; the inverse matrix gives param to theta.
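A sketch of the param-to-theta direction (my own summary, consistent with the formulas in this thread; the theta-to-param direction is just the inverse): grid_sample is a backward warp, so theta maps output normalized coordinates to input normalized coordinates, while warpAffine's M maps input pixels forward to output pixels. With the normalization N = {{2/w, 0, -1}, {0, 2/h, -1}, {0, 0, 1}} taking pixel coordinates x to normalized coordinates x_hat = N x:

x_dst = M x_src  =>  x_src = M^-1 x_dst

Substituting x = N^-1 x_hat on both sides:

x_hat_src = N M^-1 N^-1 x_hat_dst,  hence  theta = N M^-1 N^-1.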


Hi, great code! I went through the various steps and managed to get to the final transformation matrix as @Budel and @xpngzhng did, but according to my calculations I shouldn't have to invert "param" first. Obviously, applying the transformation without the inverse, I don't get the correct result as you do… Could you please explain why it is needed? Why do you invert the matrix before the change of basis?

Okay, so to drive this point completely home, here is how you convert back and forth between OpenCV and torch.F.affine_grid. (The inverse shows up because grid_sample is a backward warp: the grid stores, for each output pixel, the input location to sample from, while warpAffine's M maps input pixels forward to output pixels.)


import numpy as np


def get_N(W, H):
    """N that maps from unnormalized to normalized coordinates"""
    N = np.zeros((3, 3), dtype=np.float64)
    N[0, 0] = 2.0 / W
    N[0, 1] = 0
    N[1, 1] = 2.0 / H
    N[1, 0] = 0
    N[0, -1] = -1.0
    N[1, -1] = -1.0
    N[-1, -1] = 1.0
    return N


def get_N_inv(W, H):
    """N that maps from normalized to unnormalized coordinates"""
    # TODO: do this analytically maybe?
    N = get_N(W, H)
    return np.linalg.inv(N)


def cvt_MToTheta(M, w, h):
    """convert affine warp matrix `M` compatible with `opencv.warpAffine` to `theta` matrix
    compatible with `torch.F.affine_grid`

    Parameters
    ----------
    M : np.ndarray
        affine warp matrix shaped [2, 3]
    w : int
        width of image
    h : int
        height of image

    Returns
    -------
    np.ndarray
        theta tensor for `torch.F.affine_grid`, shaped [2, 3]
    """
    M_aug = np.concatenate([M, np.zeros((1, 3))], axis=0)
    M_aug[-1, -1] = 1.0
    N = get_N(w, h)
    N_inv = get_N_inv(w, h)
    theta = N @ M_aug @ N_inv
    theta = np.linalg.inv(theta)
    return theta[:2, :]


def cvt_ThetaToM(theta, w, h, return_inv=False):
    """convert theta matrix compatible with `torch.F.affine_grid` to affine warp matrix `M`
    compatible with `opencv.warpAffine`.

    Note:
    M works with `opencv.warpAffine`.
    To transform a set of bounding box corner points using `opencv.perspectiveTransform`, M^-1 is required

    Parameters
    ----------
    theta : np.ndarray
        theta tensor for `torch.F.affine_grid`, shaped [2, 3]
    w : int
        width of image
    h : int
        height of image
    return_inv : bool, optional
        return M^-1 instead of M.

    Returns
    -------
    np.ndarray
        affine warp matrix `M` shaped [2, 3]
    """
    theta_aug = np.concatenate([theta, np.zeros((1, 3))], axis=0)
    theta_aug[-1, -1] = 1.0
    N = get_N(w, h)
    N_inv = get_N_inv(w, h)
    M = np.linalg.inv(theta_aug)
    M = N_inv @ M @ N
    if return_inv:
        M_inv = np.linalg.inv(M)
        return M_inv[:2, :]
    return M[:2, :]

# small test with an arbitrary matrix (w, h and M chosen here for illustration)
w, h = 64, 48
M = np.array([[0.8, -0.6, 10.0],
              [0.6,  0.8,  5.0]])
theta = cvt_MToTheta(M, w, h)
M2 = cvt_ThetaToM(theta, w, h)
assert np.allclose(M, M2)


I will also make a PR to the example proposed by @xiang [here].


Hi @ptrblck,

How can I obtain the new coordinates of the corners of the transformed image, especially if I scale and rotate the original image?

Best,
Matthias

If you are applying an affine transformation (e.g. via a transformation matrix), you could apply the same transformation to the original corner coordinates to get the transformed corners.
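A minimal sketch (my own example values): transforming the four corners of a w x h image with a 2x3 pixel-space matrix M, as used by cv2.warpAffine:

import numpy as np

w, h = 64, 48
M = np.array([[0.8, -0.6, 10.0],
              [0.6,  0.8,  5.0]])  # assumed scale + rotation + translation

# corners in homogeneous (x, y, 1) pixel coordinates
corners = np.array([[0,     0,     1],
                    [w - 1, 0,     1],
                    [w - 1, h - 1, 1],
                    [0,     h - 1, 1]], dtype=np.float64)

new_corners = corners @ M.T  # each row is a transformed (x, y)
print(new_corners)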
