Hi All,
Can someone share a snippet demonstrating how to use mark_non_differentiable from autograd.Function?
For example, a comparison returns a boolean mask, which has no meaningful gradient, so you could implement a Function for it like this:
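```python
import torch
from torch.autograd import Function

class Gt(Function):
    @staticmethod
    def forward(ctx, tensor1, other):
        # other may be a tensor or a Python scalar; gt() accepts both
        mask = tensor1.gt(other)
        # the result is piecewise constant, so tell autograd not to track it
        ctx.mark_non_differentiable(mask)
        return mask

    @staticmethod
    def backward(ctx, grad_mask):
        # never needed for mask; one None per forward input
        return None, None

mask = Gt.apply(torch.randn(3, requires_grad=True), 0.0)  # mask.requires_grad is False
```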
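mark_non_differentiable is most useful when a Function returns several outputs and only some of them carry gradients. Below is a minimal sketch along the lines of the sort example in the PyTorch "Extending torch.autograd" docs. The class name `SortWithIndices` is hypothetical, and it assumes a 1-D input. The sorted values get gradients, while the integer indices are marked non-differentiable. Note that backward still receives one gradient argument per forward output.

```python
import torch
from torch.autograd import Function


class SortWithIndices(Function):
    """Hypothetical example: sort a 1-D tensor, returning (values, indices)."""

    @staticmethod
    def forward(ctx, x):
        values, indices = x.sort()
        # indices carry no gradient information; autograd will not track them
        ctx.mark_non_differentiable(indices)
        ctx.save_for_backward(indices)
        return values, indices

    @staticmethod
    def backward(ctx, grad_values, grad_indices):
        # one gradient arrives per forward output; grad_indices is ignored
        (indices,) = ctx.saved_tensors
        # values[i] == x[indices[i]], so route each gradient back to its source
        grad_x = torch.zeros_like(grad_values)
        grad_x.index_add_(0, indices, grad_values)
        return grad_x


x = torch.tensor([3.0, 1.0, 2.0], requires_grad=True)
values, indices = SortWithIndices.apply(x)
print(values.requires_grad, indices.requires_grad)  # True False

# weight each sorted position differently so the routing is visible in x.grad
(values * torch.tensor([1.0, 10.0, 100.0])).sum().backward()
print(x.grad)
```

Here `x.grad` ends up as `[100, 1, 10]`. Each input element receives the weight of the position it was sorted into, and no gradient flows through `indices`.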