Implementation of Dropout for sparse input

(Mahsa) #1

I want to implement dropout for sparse input. I know that the TensorFlow implementation is as follows, but I don’t know whether there is any way to implement it in PyTorch (the source of the following code is here):

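import tensorflow as tf  # the calls below use TF 1.x API names

def sparse_dropout(x, keep_prob, noise_shape):
    """Dropout for sparse tensors."""
    # keep_prob + U[0, 1) lies in [keep_prob, 1 + keep_prob), so its floor
    # is 1 with probability keep_prob and 0 otherwise.
    random_tensor = keep_prob
    random_tensor += tf.random_uniform(noise_shape)
    dropout_mask = tf.cast(tf.floor(random_tensor), dtype=tf.bool)
    # Keep only the stored entries selected by the mask.
    pre_out = tf.sparse_retain(x, dropout_mask)
    # Rescale the survivors so the expected value is unchanged.
    return pre_out * (1./keep_prob)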
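
For comparison, here is a minimal sketch of how the same idea might look in PyTorch, assuming the input is a sparse COO tensor. The function name and the `training` flag are my own choices for illustration, not an existing PyTorch API. It draws one random number per stored entry, keeps each entry with probability `keep_prob`, and rescales the survivors.

```python
import torch

def sparse_dropout(x, keep_prob, training=True):
    """Dropout on the stored entries of a sparse COO tensor (illustrative helper)."""
    if not training or keep_prob >= 1.0:
        return x
    x = x.coalesce()  # indices() and values() require a coalesced tensor
    indices = x.indices()
    values = x.values()
    # floor(keep_prob + U[0, 1)) is 1 with probability keep_prob, as in the TF code.
    noise = torch.rand(values.shape[0], device=values.device)
    keep_mask = torch.floor(keep_prob + noise).bool()
    # Drop the masked-out entries and rescale the rest by 1 / keep_prob.
    kept_indices = indices[:, keep_mask]
    kept_values = values[keep_mask] / keep_prob
    return torch.sparse_coo_tensor(kept_indices, kept_values, x.shape)

# Usage: a 4x4 sparse tensor with four stored ones on the diagonal.
i = torch.tensor([[0, 1, 2, 3], [0, 1, 2, 3]])
v = torch.ones(4)
x = torch.sparse_coo_tensor(i, v, (4, 4))
y = sparse_dropout(x, keep_prob=0.5)
```

Another option would be to apply `torch.nn.functional.dropout` to `x.values()` and rebuild the tensor with the original indices. That keeps the dropped entries as explicit zeros instead of removing them, which may or may not matter for your use case.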