I’m wondering how to implement bilinear upsampling with align_corners=True myself in PyTorch. The code below is my attempt; its output is close to that of nn.Upsample but not identical: the sum of squared differences is 0.0245.
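To be explicit about what I think align_corners=True means: output index i should map back to source position i * (in_size - 1) / (out_size - 1), so that the first and last samples of input and output line up exactly. A tiny sketch of just that mapping (src_coords is a hypothetical helper for illustration, not part of the script below):

import torch

def src_coords(out_size, in_size):
    # align_corners=True: the endpoints of the input and output grids coincide
    return torch.arange(out_size, dtype=torch.double) * (in_size - 1) / (out_size - 1)

print(src_coords(8, 4))  # 0.0000, 0.4286, ..., 2.5714, 3.0000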
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import torch
from PIL import Image
from torch import nn
from torchvision import transforms
def bilinear(img, upscale, align_corners=False):
    # img: (B, C, H, W) tensor; upscale: integer scale factor.
    # x runs over rows (height), y over columns (width).
    B, c, h, w = img.shape
    new_h, new_w = h * upscale, w * upscale
    if not align_corners:
        # Map output pixel centres to input pixel centres (half-pixel offsets), then clamp.
        x = torch.clip((torch.arange(new_h, dtype=torch.double, device=img.device) + 0.5) * h / new_h - 0.5, 0, h - 1)
        y = torch.clip((torch.arange(new_w, dtype=torch.double, device=img.device) + 0.5) * w / new_w - 0.5, 0, w - 1)
        x = x[:, None].expand(-1, new_w)
        y = y[None, :].expand(new_h, -1)
        x1 = torch.clip(torch.floor(x), 0, h - 2).long()
        y1 = torch.clip(torch.floor(y), 0, w - 2).long()
        x2 = x1 + 1
        y2 = y1 + 1
    else:
        # Align the corner samples: output index i maps to i * (size - 1) / (new_size - 1).
        x = torch.clip(torch.arange(new_h, dtype=torch.double, device=img.device) * (h - 1) / (new_h - 1), 0, h - 1)
        y = torch.clip(torch.arange(new_w, dtype=torch.double, device=img.device) * (w - 1) / (new_w - 1), 0, w - 1)
        x = x[:, None].expand(-1, new_w)
        y = y[None, :].expand(new_h, -1)
        x1 = torch.floor(x).long()
        y1 = torch.floor(y).long()
        x2 = torch.ceil(x).long()
        y2 = torch.ceil(y).long()
    # Gather the four neighbours of every source location; shapes broadcast to (B, C, new_h, new_w).
    q11 = img[..., x1, y1]
    q12 = img[..., x1, y2]
    q21 = img[..., x2, y1]
    q22 = img[..., x2, y2]
    # u, v are the distances from the source coordinate to the larger neighbour index;
    # each corner is weighted by the area of the rectangle opposite to it.
    u = (x2 - x)[None, None, ...]
    v = (y2 - y)[None, None, ...]
    _1_u = 1 - u
    _1_v = 1 - v
    w11 = u * v
    w12 = u * _1_v
    w21 = _1_u * v
    w22 = _1_u * _1_v
    result = q11 * w11 + q12 * w12 + q21 * w21 + q22 * w22
    return result.float()
if __name__ == '__main__':
    img = Image.open('lena.png')
    transform = transforms.Compose([transforms.PILToTensor()])
    tensor = transform(img)[None, ...].float()  # (1, 3, 512, 512)
    m1 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
    x = m1(tensor)
    y = bilinear(tensor, 2, align_corners=True)
    print(torch.allclose(x, y))        # False
    print(torch.sum((x - y) ** 2))     # sum of squared differences: 0.0245
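The per-element differences look tiny (the 0.0245 is a sum over the whole 1x3x1024x1024 output), and I compute the weights in double while nn.Upsample works on the float32 input, so I cannot tell whether this is a real bug in my weights or just floating-point rounding. These are the extra diagnostics I append at the end of the __main__ block:

    # largest per-element difference, and comparison with a small absolute tolerance
    print(torch.max(torch.abs(x - y)))
    print(torch.allclose(x, y, atol=1e-3))

Is the remaining difference expected, or is there a mistake in my implementation?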