How to use mark_non_differentiable

Hi All,

Can someone share a snippet demonstrating how to use mark_non_differentiable from autograd.Function?


For example, comparison is non-differentiable, so you could implement a Function for it like this:

import torch
from torch.autograd import Function

class Gt(Function):

    @staticmethod
    def forward(ctx, tensor1, other):
        # `other` can be another tensor or a Python scalar
        mask = tensor1.gt(other)
        # The boolean mask has no meaningful gradient, so tell autograd
        # not to track it or request a gradient for it in backward.
        ctx.mark_non_differentiable(mask)
        return mask

    @staticmethod
    def backward(ctx, grad_output):
        # Since the only output is non-differentiable, this is never
        # invoked; it is defined for completeness and returns no
        # gradients for the two forward inputs.
        return None, None
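To see the effect, here is a small self-contained check (it restates the Gt Function above so the snippet runs on its own) showing that the output returned through mark_non_differentiable is treated as a constant by autograd:

```python
import torch
from torch.autograd import Function

class Gt(Function):
    @staticmethod
    def forward(ctx, tensor1, other):
        mask = tensor1.gt(other)
        ctx.mark_non_differentiable(mask)
        return mask

    @staticmethod
    def backward(ctx, grad_output):
        # Never called: the only output is marked non-differentiable.
        return None, None

x = torch.randn(4, requires_grad=True)
mask = Gt.apply(x, 0.0)

# Even though the input requires grad, the mask does not,
# and no backward graph is built for it.
print(mask.requires_grad)  # False
print(mask.grad_fn)        # None
```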