Implementation of Dropout for sparse input

I want to implement dropout for sparse input. I know that the TensorFlow implementation is as follows, but I don't know whether there is any way to implement it in PyTorch (the source of the following code is here):

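import tensorflow as tf  # TensorFlow 1.x API (tf.random_uniform, tf.sparse_retain)

def sparse_dropout(x, keep_prob, noise_shape):
    """Dropout for sparse tensors."""
    # keep_prob + U[0, 1) floors to 1 with probability keep_prob, else to 0
    random_tensor = keep_prob
    random_tensor += tf.random_uniform(noise_shape)
    dropout_mask = tf.cast(tf.floor(random_tensor), dtype=tf.bool)
    # keep only the stored (non-zero) entries selected by the mask
    pre_out = tf.sparse_retain(x, dropout_mask)
    # rescale so the expected value is unchanged
    return pre_out * (1./keep_prob)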
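The mask in the TensorFlow code relies on a floor trick: with U drawn uniformly from [0, 1), floor(keep_prob + U) is 1 exactly when U >= 1 - keep_prob, which happens with probability keep_prob. A minimal PyTorch sketch of just that step (the helper name `keep_mask` is hypothetical, not part of either library) checks this empirically:

```python
import torch

def keep_mask(n: int, keep_prob: float) -> torch.Tensor:
    """Hypothetical helper: the TensorFlow floor trick, transcribed to PyTorch."""
    # U ~ Uniform[0, 1); floor(keep_prob + U) is 1 exactly when U >= 1 - keep_prob,
    # which happens with probability keep_prob, and 0 otherwise.
    return torch.floor(keep_prob + torch.rand(n)).bool()

torch.manual_seed(0)
mask = keep_mask(100_000, 0.8)
print(mask.float().mean().item())  # close to 0.8
```

In PyTorch, `torch.bernoulli(torch.full((n,), keep_prob))` draws the same kind of mask more directly; the floor form is shown only because it mirrors the TensorFlow code line by line.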

I am currently doing the same thing. Here is my solution:

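import torch

class SparseDropout(torch.nn.Module):
    def __init__(self, dprob=0.5):
        super(SparseDropout, self).__init__()
        # dprob is the dropout ratio;
        # convert it to a keep probability
        self.kprob = 1 - dprob

    def forward(self, x):
        x = x.coalesce()  # indices()/values() require a coalesced sparse tensor
        # floor(kprob + U[0, 1)) is 1 with probability kprob: the keep mask
        mask = (torch.rand(x.values().size(0), device=x.device) + self.kprob).floor().bool()
        rc = x.indices()[:, mask]
        val = x.values()[mask] * (1.0 / self.kprob)  # rescale the kept values
        return torch.sparse_coo_tensor(rc, val, x.size())  # keep the original shape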
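If a plain function is more convenient than a module, the same idea can be sketched as follows. This is an illustration under assumptions, not a PyTorch API: the name `sparse_dropout`, the `p`/`training` parameters, and returning the input unchanged when `training` is False (mirroring `nn.Dropout`) are my own choices, and it assumes a COO tensor with scalar values and 0 <= p < 1.

```python
import torch

def sparse_dropout(x: torch.Tensor, p: float = 0.5, training: bool = True) -> torch.Tensor:
    """Hypothetical helper: dropout on the stored entries of a sparse COO tensor."""
    if not training or p == 0.0:
        return x  # like nn.Dropout, leave the input untouched outside training
    x = x.coalesce()  # indices()/values() require a coalesced tensor
    keep_prob = 1.0 - p
    # One Bernoulli(keep_prob) draw per stored (non-zero) entry
    mask = torch.rand(x.values().size(0), device=x.device) < keep_prob
    indices = x.indices()[:, mask]
    values = x.values()[mask] / keep_prob  # rescale to preserve the expected value
    # Passing x.size() keeps the shape even if the last rows/columns are all dropped
    return torch.sparse_coo_tensor(indices, values, x.size())

# Usage: a 3x3 sparse matrix with four stored entries
i = torch.tensor([[0, 1, 2, 2], [1, 0, 0, 2]])
v = torch.tensor([1.0, 2.0, 3.0, 4.0])
x = torch.sparse_coo_tensor(i, v, (3, 3))
torch.manual_seed(0)
y = sparse_dropout(x, p=0.5)
print(y.size())  # torch.Size([3, 3])
print(y.to_dense())  # surviving entries are doubled, dropped ones are 0
```

Passing the size explicitly matters: without it, `sparse_coo_tensor` infers the shape from the largest surviving index, so the output can shrink when trailing entries are dropped.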

Hi KuanS, thanks for the solution. I am currently working on the same problem. Unfortunately, the proposal does not work when I use it as a plain function. Do you have any idea how to compute rc = x.indices()[:, mask] when the input features x are not a sparse tensor? If I convert x to a sparse tensor, the indices lie along the second dimension. Moreover, I suspect that both x._values() and x._indices() are deprecated, because calling them raises an error. Thank you in advance.
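On the dense-vs-sparse question: for a COO tensor, `indices()` has shape `(sparse_dim, nnz)`, one column per stored entry, so a per-entry mask selects along dimension 1. The public `indices()`/`values()` also require a coalesced tensor (`x.coalesce()`), and calling them on a dense (strided) tensor raises an error, which may be the error described above. One possible sketch, assuming 0 <= p < 1 and using a hypothetical helper name `dropout_any`, falls back to ordinary `F.dropout` for dense input and applies mask-based dropout only to sparse input:

```python
import torch
import torch.nn.functional as F

def dropout_any(x: torch.Tensor, p: float = 0.5, training: bool = True) -> torch.Tensor:
    """Hypothetical helper: dropout for dense or sparse COO input (assumes 0 <= p < 1)."""
    if not x.is_sparse:
        # Dense (strided) input has no indices()/values(); use ordinary dropout.
        return F.dropout(x, p=p, training=training)
    if not training:
        return x
    x = x.coalesce()  # the public indices()/values() require a coalesced tensor
    # indices() has shape (sparse_dim, nnz): one column per stored entry,
    # so the per-entry mask indexes dimension 1.
    mask = torch.rand(x.values().size(0), device=x.device) >= p
    return torch.sparse_coo_tensor(
        x.indices()[:, mask], x.values()[mask] / (1.0 - p), x.size()
    )

dense = torch.tensor([[0.0, 1.0], [2.0, 0.0]])
sparse = dense.to_sparse().coalesce()
print(sparse.indices())  # one column per non-zero: (0, 1) and (1, 0)
torch.manual_seed(0)
print(dropout_any(sparse, p=0.5).to_dense())
```

For features that are genuinely dense, `F.dropout` already does the job, so converting to sparse is mainly worthwhile when most entries are zero.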