About torch.nn.Upsample?


I implemented bilinear interpolation in PyTorch like this. It loops over every output pixel, so it is O(n^2) for an n×n output:

import torch

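def bilinear_interpolation(tensor_data, out_dim):
    # tensor_data: (batch, channels, src_h, src_w); out_dim: (dst_w, dst_h)
    batch, channels, src_h, src_w = tensor_data.shape
    dst_h, dst_w = out_dim[1], out_dim[0]

    if src_h == dst_h and src_w == dst_w:
        return tensor_data.clone()  # torch tensors have .clone(), not .copy()

    dst_img = torch.zeros((batch, channels, dst_h, dst_w),
                          dtype=tensor_data.dtype, device=tensor_data.device)

    scale_x, scale_y = float(src_w) / dst_w, float(src_h) / dst_h

    # Python-level loop: dst_h * dst_w iterations
    for h in range(dst_h):
        for w in range(dst_w):
            ...  # per-pixel interpolation (elided in this post)
    return dst_img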
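The loop body is elided above. For reference, here is one hypothetical way to fill it in, assuming the half-pixel mapping that PyTorch uses for mode="bilinear" with align_corners=False (the name bilinear_interpolation_loop is made up for this sketch). It handles all batches and channels per pixel, but it still runs dst_h * dst_w Python iterations, each launching several small tensor operations.

```python
import torch
import torch.nn.functional as F


def bilinear_interpolation_loop(tensor_data, out_dim):
    """Hypothetical loop-based bilinear resize; out_dim is (dst_w, dst_h)."""
    batch, channels, src_h, src_w = tensor_data.shape
    dst_h, dst_w = out_dim[1], out_dim[0]
    dst_img = torch.zeros((batch, channels, dst_h, dst_w),
                          dtype=tensor_data.dtype, device=tensor_data.device)
    scale_x, scale_y = src_w / dst_w, src_h / dst_h

    for h in range(dst_h):
        # Half-pixel mapping (align_corners=False); negative coords clamp to 0
        src_y = max((h + 0.5) * scale_y - 0.5, 0.0)
        y0 = min(int(src_y), src_h - 1)
        y1 = min(y0 + 1, src_h - 1)
        wy = src_y - y0
        for w in range(dst_w):
            src_x = max((w + 0.5) * scale_x - 0.5, 0.0)
            x0 = min(int(src_x), src_w - 1)
            x1 = min(x0 + 1, src_w - 1)
            wx = src_x - x0
            # Blend the four neighbours for every batch/channel at once
            top = (1 - wx) * tensor_data[:, :, y0, x0] + wx * tensor_data[:, :, y0, x1]
            bottom = (1 - wx) * tensor_data[:, :, y1, x0] + wx * tensor_data[:, :, y1, x1]
            dst_img[:, :, h, w] = (1 - wy) * top + wy * bottom
    return dst_img


# Small sanity check against PyTorch's own bilinear resize
x = torch.rand(1, 3, 8, 8)
y = bilinear_interpolation_loop(x, (16, 16))
ref = F.interpolate(x, size=(16, 16), mode="bilinear", align_corners=False)
print((y - ref).abs().max())  # should be close to 0 (float32 rounding)
```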

My test upsamples a tensor of size (1, 3, 224, 224) to (1, 3, 448, 448). With my function this takes about 14 s, but with torch.nn.Upsample it takes about 0.003 s.
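A minimal sketch of the built-in path for the same test size, assuming bilinear mode with align_corners=False. It also shows that nn.Upsample and torch.nn.functional.interpolate produce the same tensor. A single timed call is only a rough measurement; numbers vary by machine and the first call may include warm-up cost.

```python
import time

import torch
import torch.nn as nn
import torch.nn.functional as F

x = torch.rand(1, 3, 224, 224)

# nn.Upsample takes the target size as (H, W)
up = nn.Upsample(size=(448, 448), mode="bilinear", align_corners=False)

start = time.perf_counter()
y_module = up(x)
elapsed = time.perf_counter() - start

# The functional form with the same arguments
y_func = F.interpolate(x, size=(448, 448), mode="bilinear", align_corners=False)

print(y_module.shape, f"{elapsed:.4f} s")
print(torch.equal(y_module, y_func))
```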

Why is there such a big difference? I suspect the Python loop is the cause.
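The suspicion about the loop is consistent with how the work is split: at 448×448 the loop runs about 200,000 Python iterations, each dispatching several tiny tensor operations. A vectorized version computes all source coordinates and weights once and then does a handful of whole-tensor operations. The sketch below is one such version under the same assumption (half-pixel mapping, align_corners=False); the names bilinear_interpolation_vec and _source_coords are hypothetical. It is still not PyTorch's native kernel, so expect it to be much faster than the loop but not necessarily as fast as nn.Upsample.

```python
import torch
import torch.nn.functional as F


def _source_coords(dst_size, src_size, device):
    # Half-pixel mapping (align_corners=False): src = (dst + 0.5) * scale - 0.5
    scale = src_size / dst_size
    coords = (torch.arange(dst_size, device=device, dtype=torch.float32) + 0.5) * scale - 0.5
    coords = coords.clamp(min=0)  # negative coordinates snap to the first pixel
    i0 = coords.floor().long().clamp(max=src_size - 1)
    i1 = (i0 + 1).clamp(max=src_size - 1)
    frac = coords - i0.to(coords.dtype)
    return i0, i1, frac


def bilinear_interpolation_vec(tensor_data, out_dim):
    """Hypothetical vectorized bilinear resize; out_dim is (dst_w, dst_h)."""
    _, _, src_h, src_w = tensor_data.shape
    dst_h, dst_w = out_dim[1], out_dim[0]
    y0, y1, wy = _source_coords(dst_h, src_h, tensor_data.device)
    x0, x1, wx = _source_coords(dst_w, src_w, tensor_data.device)

    wy = wy.to(tensor_data.dtype).view(-1, 1)  # (dst_h, 1): broadcasts over columns
    wx = wx.to(tensor_data.dtype).view(1, -1)  # (1, dst_w): broadcasts over rows

    # Gather rows, then columns; each corner tensor is (N, C, dst_h, dst_w)
    top = tensor_data[:, :, y0]  # (N, C, dst_h, src_w)
    bot = tensor_data[:, :, y1]
    tl, tr = top[..., x0], top[..., x1]
    bl, br = bot[..., x0], bot[..., x1]

    upper = tl * (1 - wx) + tr * wx
    lower = bl * (1 - wx) + br * wx
    return upper * (1 - wy) + lower * wy


x = torch.rand(1, 3, 224, 224)
y = bilinear_interpolation_vec(x, (448, 448))
ref = F.interpolate(x, size=(448, 448), mode="bilinear", align_corners=False)
print((y - ref).abs().max())  # should be close to 0 (float32 rounding)
```

Because it only uses differentiable tensor operations, autograd can backpropagate through it if it is used inside a model.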

However, I can't find the source code of torch.nn.Upsample.
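One way to follow the source from Python is the standard-library inspect module. In current PyTorch releases, nn.Upsample.forward is a thin wrapper around torch.nn.functional.interpolate, and for 4-D bilinear input interpolate calls a compiled ATen operator (upsample_bilinear2d). That operator's C++/CUDA implementation lives in the PyTorch repository under aten/src/ATen/native/ rather than in Python, which is likely why it is hard to find. The exact source text printed below depends on your PyTorch version.

```python
import inspect

import torch.nn as nn
import torch.nn.functional as F

# Python side of nn.Upsample: forward() delegates to F.interpolate
upsample_src = inspect.getsource(nn.Upsample.forward)
print(upsample_src)

# Lines in F.interpolate that dispatch to the compiled bilinear kernel
interp_src = inspect.getsource(F.interpolate)
for line in interp_src.splitlines():
    if "upsample_bilinear2d" in line:
        print(line.strip())
```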

Can anyone help?

I want to write my own version of torch.nn.Upsample for my project.

Thank you, and best wishes!