TensorFlow Equivalent of Elastic Deformation for PyTorch

I am trying to implement elastic deformation by sampling control points on a regularly spaced 100×100 grid with σ = 20, using bilinear interpolation for the image and nearest-neighbour interpolation for the masks.
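One way to sample control points on a regular grid, sketched in NumPy/SciPy: reading "100*100 grid" as control points spaced `spacing` = 100 pixels apart, with offsets drawn from a normal distribution with standard deviation σ = 20 pixels, is an assumption, and the helper name `control_point_displacement` is hypothetical. The random offsets at the control points are bilinearly interpolated to every pixel.

```python
import numpy as np
from scipy.ndimage import map_coordinates


def control_point_displacement(shape, spacing=100, sigma=20.0, rng=None):
    """Dense (2, H, W) displacement field from random control-point offsets.

    Assumption: control points sit every `spacing` pixels and each offset is
    drawn from N(0, sigma**2) pixels; values in between are bilinear.
    """
    rng = np.random.default_rng() if rng is None else rng
    h, w = shape
    gh = int(np.ceil(h / spacing)) + 1  # enough control points to cover every row
    gw = int(np.ceil(w / spacing)) + 1  # ... and every column
    coarse = rng.normal(0.0, sigma, size=(2, gh, gw))  # (row, column) offsets

    # Position of every pixel expressed in control-grid units
    rows, cols = np.meshgrid(np.arange(h) / spacing, np.arange(w) / spacing, indexing='ij')
    # order=1 interpolates the control-point offsets bilinearly
    return np.stack([map_coordinates(c, [rows, cols], order=1, mode='nearest')
                     for c in coarse])
```

The two returned channels could stand in for `dx` and `dy` in the function below, e.g. `dx, dy = control_point_displacement(image.shape[:2])`.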
I found this TensorFlow implementation by DeepMind.
So far, I have been using the implementation below as a Lambda transform:

```python
import numpy as np
from PIL import Image
from scipy.ndimage import gaussian_filter, map_coordinates


def elastic_transform(image, alpha=1000, sigma=20, spline_order=1, mode='nearest', random_state=np.random):
    """Elastic deformation of image as described in [Simard2003]_.

    .. [Simard2003] Simard, Steinkraus and Platt, "Best Practices for
       Convolutional Neural Networks applied to Visual Document Analysis", in
       Proc. of the International Conference on Document Analysis and
       Recognition, 2003.
    """
    image = np.array(image)  # PIL image -> H x W x C array
    assert image.ndim == 3
    shape = image.shape[:2]

    # Per-pixel random displacements in [-1, 1], smoothed by a Gaussian of
    # width sigma and scaled by alpha
    dx = gaussian_filter((random_state.rand(*shape) * 2 - 1),
                         sigma, mode="constant", cval=0) * alpha
    dy = gaussian_filter((random_state.rand(*shape) * 2 - 1),
                         sigma, mode="constant", cval=0) * alpha

    # Displaced sampling coordinates (row + dx, column + dy) for every pixel
    x, y = np.meshgrid(np.arange(shape[0]), np.arange(shape[1]), indexing='ij')
    indices = [np.reshape(x + dx, (-1, 1)), np.reshape(y + dy, (-1, 1))]
    result = np.empty_like(image)
    for i in range(image.shape[2]):
        # spline_order sets the interpolation order; mode sets how
        # out-of-range coordinates are handled
        result[:, :, i] = map_coordinates(
            image[:, :, i], indices, order=spline_order, mode=mode).reshape(shape)
    result = Image.fromarray(result)
    return result
```
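For reference, in `scipy.ndimage.map_coordinates` `order=1` is a first-order spline, which in 2-D is bilinear interpolation, and `order=0` is nearest neighbour; only orders above 1 apply the spline prefilter. Because `elastic_transform` draws a fresh random field on every call, applying it separately to an image and its mask would deform them differently. A minimal sketch that reuses one field for both follows; the helper `warp_image_and_mask` is hypothetical, and `dx`/`dy` are assumed to be H x W row/column displacements such as the ones computed above.

```python
import numpy as np
from scipy.ndimage import map_coordinates


def warp_image_and_mask(image, mask, dx, dy):
    """Warp an H x W x C image and an H x W mask with the same displacement field.

    dx, dy: H x W arrays of per-pixel displacements along rows and columns.
    """
    h, w = mask.shape
    rows, cols = np.meshgrid(np.arange(h), np.arange(w), indexing='ij')
    coords = np.stack([rows + dx, cols + dy])  # (2, H, W) sampling positions

    # order=1: linear spline, i.e. bilinear interpolation in 2-D
    warped_image = np.stack(
        [map_coordinates(image[:, :, c], coords, order=1, mode='nearest')
         for c in range(image.shape[2])], axis=-1).astype(image.dtype)
    # order=0: nearest neighbour, so no new label values are invented
    warped_mask = map_coordinates(mask, coords, order=0, mode='nearest')
    return warped_image, warped_mask
```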

The issue with the above implementation is that it relies on SciPy's spline interpolation (`map_coordinates`).
Is there a PyTorch equivalent of the TensorFlow implementation, or is there another way to implement my requirement?
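A possible PyTorch route replaces `map_coordinates` with `torch.nn.functional.grid_sample`, which supports `mode='bilinear'` for the image and `mode='nearest'` for the mask, and uses `torch.nn.functional.interpolate` to upsample random control-point offsets to a dense field. This is a sketch, not a port of the DeepMind TensorFlow code: the helper name `elastic_deform_torch`, the reading of the grid as control points every `spacing` = 100 pixels with offsets of standard deviation `sigma` = 20 pixels, and `padding_mode='border'` are all assumptions.

```python
import torch
import torch.nn.functional as F


def elastic_deform_torch(image, mask, spacing=100, sigma=20.0, generator=None):
    """Apply one random elastic deformation to an image and its mask.

    image: (C, H, W) tensor; mask: (H, W) integer label tensor.
    Assumption: offsets (pixels, std `sigma`) are drawn at control points
    `spacing` pixels apart and bilinearly upsampled to a dense field.
    """
    _, h, w = image.shape
    device = image.device
    gh = -(-h // spacing) + 1  # ceil(h / spacing) + 1 control points along the height
    gw = -(-w // spacing) + 1  # ceil(w / spacing) + 1 control points along the width

    # Channel 0: row (y) displacement, channel 1: column (x) displacement
    coarse = torch.randn(1, 2, gh, gw, generator=generator, device=device) * sigma
    # With align_corners=True, control point i lands exactly on pixel i * spacing;
    # the dense field is then cropped to the image size.
    dense = F.interpolate(coarse, size=((gh - 1) * spacing + 1, (gw - 1) * spacing + 1),
                          mode='bilinear', align_corners=True)[0, :, :h, :w]

    ys = torch.arange(h, dtype=torch.float32, device=device).view(h, 1).expand(h, w)
    xs = torch.arange(w, dtype=torch.float32, device=device).view(1, w).expand(h, w)
    src_y = ys + dense[0]
    src_x = xs + dense[1]

    # grid_sample expects (x, y) pairs normalised to [-1, 1], shape (1, H, W, 2)
    grid = torch.stack([2 * src_x / (w - 1) - 1, 2 * src_y / (h - 1) - 1], dim=-1)[None]

    # padding_mode='border' plays the role of SciPy's mode='nearest'
    warped_image = F.grid_sample(image[None].float(), grid, mode='bilinear',
                                 padding_mode='border', align_corners=True)[0]
    warped_mask = F.grid_sample(mask[None, None].float(), grid, mode='nearest',
                                padding_mode='border', align_corners=True)[0, 0]
    return warped_image, warped_mask.to(mask.dtype)
```

Both outputs share one sampling grid, so the mask stays aligned with the image; the image comes back as float32 and the mask keeps its original dtype.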
