Dear all,
I implemented bilinear interpolation with PyTorch like this; it loops over every output pixel in Python, so it is O(n^2) in the output size:
import torch

def bilinear_interpolation(tensor_data, out_dim):
    batch, channels, src_h, src_w = tensor_data.shape
    dst_h, dst_w = out_dim[1], out_dim[0]  # out_dim is (width, height)
    if src_h == dst_h and src_w == dst_w:
        return tensor_data.clone()  # tensors have .clone(), not .copy()
    dst_img = torch.zeros((batch, channels, dst_h, dst_w))
    scale_x, scale_y = float(src_w) / dst_w, float(src_h) / dst_h
    for h in range(dst_h):
        for w in range(dst_w):
            # ... per-pixel blend of the four nearest source pixels
            # (omitted) ...
            pass
    return dst_img
I tested it by interpolating a tensor of size (1, 3, 224, 224) to (1, 3, 448, 448).
My function takes about 14 s, but when I use torch.nn.Upsample it takes about 0.003 s.
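For reference, the comparison I ran looks roughly like this (a minimal sketch; the exact Upsample arguments here are my assumption):

import time
import torch

x = torch.randn(1, 3, 224, 224)

# the Upsample configuration (bilinear, align_corners=False) is an assumption
up = torch.nn.Upsample(size=(448, 448), mode='bilinear', align_corners=False)
t0 = time.time()
y_fast = up(x)
print('nn.Upsample :', time.time() - t0, 's')

t0 = time.time()
y_slow = bilinear_interpolation(x, (448, 448))
print('loop version:', time.time() - t0, 's')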
Why is the difference so large? I think it may be caused by the Python loops.
However, I cannot find the source code for torch.nn.Upsample.
Can anyone help me?
I want to rewrite torch.nn.Upsample for my project.
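As a starting point, here is a vectorized sketch I put together that removes the Python loops. It is not the real Upsample implementation; it assumes the align_corners=False coordinate convention with border clamping, and the function name bilinear_interpolation_vec is mine:

import torch

def bilinear_interpolation_vec(x, out_dim):
    # Vectorized bilinear resize; out_dim is (width, height) as above.
    batch, channels, src_h, src_w = x.shape
    dst_h, dst_w = out_dim[1], out_dim[0]
    scale_x, scale_y = src_w / dst_w, src_h / dst_h

    # Source coordinate of every destination pixel, computed once
    # (align_corners=False convention, clamped at the borders).
    xs = ((torch.arange(dst_w, dtype=x.dtype, device=x.device) + 0.5)
          * scale_x - 0.5).clamp(0, src_w - 1)
    ys = ((torch.arange(dst_h, dtype=x.dtype, device=x.device) + 0.5)
          * scale_y - 0.5).clamp(0, src_h - 1)

    x0, y0 = xs.floor().long(), ys.floor().long()
    x1 = (x0 + 1).clamp(max=src_w - 1)
    y1 = (y0 + 1).clamp(max=src_h - 1)

    # Fractional parts become the blending weights.
    wx = (xs - x0.to(x.dtype)).view(1, 1, 1, dst_w)
    wy = (ys - y0.to(x.dtype)).view(1, 1, dst_h, 1)

    # Advanced indexing gathers whole rows at once: no Python loops.
    row0, row1 = x[:, :, y0, :], x[:, :, y1, :]
    top    = row0[..., x0] * (1 - wx) + row0[..., x1] * wx
    bottom = row1[..., x0] * (1 - wx) + row1[..., x1] * wx
    return top * (1 - wy) + bottom * wy

The idea is to precompute the source coordinates and weights for all rows and columns once, then let tensor indexing and broadcasting do the per-pixel work in C instead of Python.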
Thank you, best wishes!