How to implement box upscaling (nearest neighbors) in PyTorch?

Dear all, I am having trouble porting the following function, written in TensorFlow, to PyTorch. Could anyone help me?

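import tensorflow as tf

def upscale2d(x, n):
    """Box upscaling (also called nearest neighbors).
    Args:
        x: 4D tensor in NHWC format.
        n: integer scale (must be a power of 2).
    Returns:
        4D tensor upscaled by a factor n.
    """
    if n == 1:
        return x
    # Tile the batch n**2 times, then let batch_to_space scatter the copies
    # into n x n spatial blocks, so every pixel is repeated n times per axis.
    # Note: this is the TF 1.x argument order (input, crops, block_size).
    return tf.batch_to_space(tf.tile(x, [n**2, 1, 1, 1]), [[0, 0], [0, 0]], n)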
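
For what it's worth, here is a minimal sketch of how the same box upscaling might look in PyTorch. It is not a line-by-line port of the TensorFlow code: instead of tiling the batch, it repeats each pixel n times along the height and width axes with `repeat_interleave`, which should give the same result. PyTorch tensors are usually laid out as NCHW rather than NHWC, so the `channels_last` flag is a hypothetical parameter added here only to pick the spatial axes. For float NCHW input, `torch.nn.functional.interpolate(x, scale_factor=n, mode="nearest")` should be an equivalent alternative.

```python
import torch


def upscale2d(x: torch.Tensor, n: int, channels_last: bool = False) -> torch.Tensor:
    """Box (nearest-neighbor) upscaling of a 4D tensor by an integer factor n.

    channels_last is an assumption of this sketch, not part of the original
    function: False means NCHW (the usual PyTorch layout), True means NHWC.
    """
    if n == 1:
        return x
    # Pick the height and width axes for the chosen layout.
    h_dim, w_dim = (1, 2) if channels_last else (2, 3)
    # Repeat every row n times, then every column n times.
    return x.repeat_interleave(n, dim=h_dim).repeat_interleave(n, dim=w_dim)


# Small usage example on a 2x2 image in NCHW layout.
x = torch.arange(4.0).reshape(1, 1, 2, 2)
y = upscale2d(x, 2)
print(y.shape)
print(y[0, 0])
```

Each input pixel ends up as an n x n block in the output, which is what the tile plus `batch_to_space` trick produces in the TensorFlow version.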