Thanks for the link.
But could I use this with different crop sizes?
From the `grid_sample` documentation:
input – input batch of images, of size (N x C x H_in x W_in)
grid – flow field of size (N x H_out x W_out x 2)
Does that limit it to a single, fixed crop size for the whole batch? In other words, can crops of different sizes only be taken with bs=1?
# Some of this code is adapted from:
# https://discuss.pytorch.org/t/cropping-a-minibatch-of-images-each-image-a-bit-differently/12247/5
###############
import cv2
import torch
import numpy as np
#############
def build_grid(bs, source_hgh, source_wid, target_y0, target_x0, target_hgh,
               target_wid, align_corners=False):
    """Build a ``grid_sample`` grid that crops one fixed-size window per sample.

    Every sample in the batch gets a crop of the same size
    (``target_hgh`` x ``target_wid``) but its own top-left corner. The crop
    size must be shared because ``grid_sample`` takes a single grid tensor of
    shape (N, H_out, W_out, 2).

    Args:
        bs: number of samples; the first ``bs`` entries of ``target_y0`` and
            ``target_x0`` are used.
        source_hgh, source_wid: height and width of the input images, in
            pixels.
        target_y0, target_x0: per-sample top-left row / column of the crop, in
            source pixels (a sequence or 1-D tensor with at least ``bs``
            entries).
        target_hgh, target_wid: crop height and width, in pixels.
        align_corners: must equal the ``align_corners`` value later passed to
            ``torch.nn.functional.grid_sample``. The default False matches the
            PyTorch default since version 1.3.

    Returns:
        A float32 tensor of shape (bs, target_hgh, target_wid, 2) holding
        normalized (x, y) sampling locations. Cell [n, i, j] points at the
        centre of source pixel (row target_y0[n] + i, column target_x0[n] + j),
        so sampling reproduces the crop exactly.

    Raises:
        ValueError: if fewer than ``bs`` crop origins are supplied.
    """
    y0 = torch.as_tensor(target_y0, dtype=torch.float32).reshape(-1)
    x0 = torch.as_tensor(target_x0, dtype=torch.float32).reshape(-1)
    if y0.numel() < bs or x0.numel() < bs:
        raise ValueError(
            f"need at least {bs} crop origins, got {y0.numel()} rows "
            f"and {x0.numel()} columns")

    # Integer pixel indices covered by each crop: rows (bs, H_out) and
    # cols (bs, W_out). arange keeps neighbouring samples exactly one pixel
    # apart; linspace(y0, y0 + h, h) would stretch the crop by one pixel.
    rows = y0[:bs, None] + torch.arange(target_hgh, dtype=torch.float32)
    cols = x0[:bs, None] + torch.arange(target_wid, dtype=torch.float32)

    if align_corners:
        # -1 / +1 are the centres of the first / last pixel. max(..., 1)
        # guards size-1 images, where any coordinate maps to pixel 0.
        rows = rows * (2.0 / max(source_hgh - 1, 1)) - 1.0
        cols = cols * (2.0 / max(source_wid - 1, 1)) - 1.0
    else:
        # -1 / +1 are the outer image edges; the centre of pixel i is at
        # (2 * i + 1) / size - 1.
        rows = (2.0 * rows + 1.0) / source_hgh - 1.0
        cols = (2.0 * cols + 1.0) / source_wid - 1.0

    # Broadcast to the full (bs, H_out, W_out) layout without copying.
    grid_y = rows[:, :, None].expand(bs, target_hgh, target_wid)
    grid_x = cols[:, None, :].expand(bs, target_hgh, target_wid)
    # grid_sample expects the last dimension ordered (x, y).
    return torch.stack([grid_x, grid_y], dim=-1)
##############
bs = 2
source_hgh = 100
source_wid = 120
target_hgh = 50
target_wid = 60
# Per-sample top-left corner (row, column) of the crop, in source pixels.
target_y0 = torch.tensor([0, 20])
target_x0 = torch.tensor([10, 0])
##############
img0 = cv2.imread('1.jpg')
img1 = cv2.imread('2.jpg')
# cv2.imread reports a missing or unreadable file by returning None instead
# of raising, and slicing a too-small image silently yields a smaller array.
# Check both up front rather than failing later with a confusing shape error.
for path, img in (('1.jpg', img0), ('2.jpg', img1)):
    if img is None:
        raise FileNotFoundError(f'cv2.imread could not read {path!r}')
    if img.shape[0] < source_hgh or img.shape[1] < source_wid:
        raise ValueError(
            f'{path!r} is {img.shape[1]}x{img.shape[0]} pixels; '
            f'need at least {source_wid}x{source_hgh}')
m0 = img0[:source_hgh, :source_wid, :]
m1 = img1[:source_hgh, :source_wid, :]
cv2.imwrite('m0.jpg', m0)
cv2.imwrite('m1.jpg', m1)
# HWC uint8 -> CHW, then stack into a float32 (bs, 3, H, W) batch.
# bs must equal the number of images stacked here.
n0 = torch.from_numpy(np.transpose(m0, (2, 0, 1)))
n1 = torch.from_numpy(np.transpose(m1, (2, 0, 1)))
t = torch.stack([n0, n1]).float()
grid = build_grid(bs, source_hgh, source_wid, target_y0, target_x0, target_hgh, target_wid)
# align_corners=False is the default since PyTorch 1.3. Passing it explicitly
# keeps results stable across versions and silences the UserWarning.
crp = torch.nn.functional.grid_sample(t, grid, align_corners=False)
# CHW float -> contiguous HWC uint8; write 8-bit data explicitly instead of
# relying on OpenCV's implicit float conversion.
r0 = np.ascontiguousarray(
    np.clip(np.transpose(crp[0].numpy(), (1, 2, 0)), 0, 255).round().astype(np.uint8))
r1 = np.ascontiguousarray(
    np.clip(np.transpose(crp[1].numpy(), (1, 2, 0)), 0, 255).round().astype(np.uint8))
cv2.imwrite('r0.jpg', r0)
cv2.imwrite('r1.jpg', r1)
#######
OK, I see. With F.affine_grid(theta, size), the bounding boxes can have different sizes and locations, but every crop is resized to the same output size given by `size`.
import cv2
import torch
import numpy as np
##############
bs = 2
source_width = 100
source_height = 120
output_width = 70
output_height = 60
# Per-sample crop boxes in source pixels: (x0, y0) is the top-left pixel and
# (x1, y1) the bottom-right pixel, both sampled (inclusive).
target_y0 = torch.tensor([0, 20], dtype=torch.float)
target_x0 = torch.tensor([10, 0], dtype=torch.float)
target_y1 = torch.tensor([60, 80], dtype=torch.float)
target_x1 = torch.tensor([90, 60], dtype=torch.float)
# Map affine_grid's output square [-1, 1]^2 onto each box. The (size - 1)
# denominators follow the align_corners=True convention (-1 / +1 are the
# centres of the first / last pixel), so affine_grid and grid_sample must
# both be called with align_corners=True. The default False would shift and
# scale every box slightly.
theta = torch.zeros(bs, 2, 3)
theta[:, 0, 0] = (target_x1 - target_x0) / (source_width - 1)
theta[:, 0, 2] = (target_x1 + target_x0 - source_width + 1) / (source_width - 1)
theta[:, 1, 1] = (target_y1 - target_y0) / (source_height - 1)
theta[:, 1, 2] = (target_y1 + target_y0 - source_height + 1) / (source_height - 1)
grid = torch.nn.functional.affine_grid(
    theta, (bs, 3, output_height, output_width), align_corners=True)
###############
img0 = cv2.imread('1.png')
img1 = cv2.imread('2.png')
# cv2.imread returns None on failure, and slicing a too-small image silently
# yields a smaller array; check both before building the batch.
for path, img in (('1.png', img0), ('2.png', img1)):
    if img is None:
        raise FileNotFoundError(f'cv2.imread could not read {path!r}')
    if img.shape[0] < source_height or img.shape[1] < source_width:
        raise ValueError(
            f'{path!r} is {img.shape[1]}x{img.shape[0]} pixels; '
            f'need at least {source_width}x{source_height}')
m0 = img0[:source_height, :source_width, :]
m1 = img1[:source_height, :source_width, :]
cv2.imwrite('m0.jpg', m0)
cv2.imwrite('m1.jpg', m1)
# HWC uint8 -> CHW, then stack into a float32 (bs, 3, H, W) batch.
n0 = torch.from_numpy(np.transpose(m0, (2, 0, 1)))
n1 = torch.from_numpy(np.transpose(m1, (2, 0, 1)))
t = torch.stack([n0, n1]).float()
##############
# Must use the same align_corners value as the affine_grid call above.
crp = torch.nn.functional.grid_sample(t, grid, align_corners=True)
# CHW float -> contiguous HWC uint8 for cv2.imwrite.
r0 = np.ascontiguousarray(
    np.clip(np.transpose(crp[0].numpy(), (1, 2, 0)), 0, 255).round().astype(np.uint8))
r1 = np.ascontiguousarray(
    np.clip(np.transpose(crp[1].numpy(), (1, 2, 0)), 0, 255).round().astype(np.uint8))
cv2.imwrite('r0.jpg', r0)
cv2.imwrite('r1.jpg', r1)