How to implement bilinear upsampling (align_corners=True)?

I’m wondering how to implement bilinear upsampling with align_corners=True.
The code below is my attempt; compared with nn.Upsample, the sum of squared errors (not the mean) is 0.0245.

#!/usr/bin/env python
# _*_ coding:utf-8 _*_
import torch
from PIL import Image
from torch import nn
from torchvision import transforms


def _source_indices(in_size, out_size, align_corners, device):
    """Map output positions along one axis to their two bilinear source taps.

    Args:
        in_size: length of the input axis.
        out_size: length of the output axis.
        align_corners: coordinate convention (see ``bilinear``).
        device: device on which the index/weight tensors are created.

    Returns:
        ``(lo, hi, frac)``, each of length ``out_size``: the lower and upper
        neighbouring input indices (both within ``[0, in_size - 1]``) and the
        float64 weight of ``hi``; the weight of ``lo`` is ``1 - frac``.
    """
    dst = torch.arange(out_size, dtype=torch.double, device=device)
    if align_corners:
        # Corners align: dst 0 -> src 0 and dst out_size-1 -> src in_size-1.
        # A single output sample has no span to divide, so it maps to src 0;
        # without this guard (out_size - 1) == 0 yields 0 / 0 = NaN indices.
        scale = (in_size - 1) / (out_size - 1) if out_size > 1 else 0.0
        src = dst * scale
    else:
        # Pixel centres align (half-pixel offset).
        src = (dst + 0.5) * in_size / out_size - 0.5
    # Samples outside the input replicate the edge pixel.
    src = torch.clip(src, 0, in_size - 1)
    lo = torch.floor(src).long()
    # At the last input index both taps coincide (frac is 0 there anyway).
    hi = torch.clip(lo + 1, max=in_size - 1)
    frac = src - lo
    return lo, hi, frac


def bilinear(img, upscale, align_corners=False):
    """Bilinearly upsample a batch of images by an integer factor.

    Intended to mirror ``nn.Upsample(scale_factor=upscale, mode='bilinear',
    align_corners=align_corners)``.

    Args:
        img: tensor of shape (B, C, H, W); any dtype that can be multiplied
            by a float64 tensor.
        upscale: integer factor applied to both spatial dimensions.
        align_corners: if True, source coordinate = dst * (in - 1) / (out - 1),
            so corner pixels of input and output coincide; if False,
            source = (dst + 0.5) * in / out - 0.5 (pixel centres align).

    Returns:
        float32 tensor of shape (B, C, H * upscale, W * upscale).

    Note:
        Coordinates and weights are computed in float64 here. PyTorch's own
        kernel computes them in the input's precision (float32 for float
        input), so results agree only up to float32 rounding of the source
        coordinates. Compare with a tolerance rather than exactly.
    """
    _, _, h, w = img.shape
    new_h, new_w = h * upscale, w * upscale

    x1, x2, dx = _source_indices(h, new_h, align_corners, img.device)
    y1, y2, dy = _source_indices(w, new_w, align_corners, img.device)
    # Row taps as column vectors and column taps as row vectors: advanced
    # indexing broadcasts them into a (new_h, new_w) sampling grid.
    x1, x2, dx = x1[:, None], x2[:, None], dx[:, None]
    y1, y2, dy = y1[None, :], y2[None, :], dy[None, :]

    # Interpolate along the width first, then blend the two rows.
    top = img[..., x1, y1] * (1 - dy) + img[..., x1, y2] * dy
    bottom = img[..., x2, y1] * (1 - dy) + img[..., x2, y2] * dy
    result = top * (1 - dx) + bottom * dx

    return result.float()


if __name__ == '__main__':
    # Compare bilinear() against PyTorch's reference upsampler on a test image.
    with Image.open('lena.png') as img:
        tensor = transforms.PILToTensor()(img)[None, ...].float()  # (1, C, H, W)
    reference = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)(tensor)
    ours = bilinear(tensor, 2, align_corners=True)
    diff = reference - ours
    # Bit-exact equality is not expected: bilinear() computes coordinates in
    # float64 while PyTorch's kernel uses float32 for float input, so check
    # agreement with a tolerance and report the size of the discrepancy.
    print('allclose (atol=1e-2):', torch.allclose(reference, ours, atol=1e-2))
    print('sum of squared errors:', torch.sum(diff ** 2).item())
    print('mean squared error:', torch.mean(diff ** 2).item())
    print('max abs error:', diff.abs().max().item())

This might help you validate your implementation.