Modifying the gradient in-place does not work with torch.compile. If I instead clone the incoming gradient, zero out the elements I want to zero, and then pass the clone on, it works. But doesn't that use too much memory?
What would be a way to achieve this without blowing up memory while still being able to use torch.compile?
from typing import Tuple

from torch import Tensor, clamp
from torch.autograd import Function
from torch.autograd.function import FunctionCtx


class StraightThroughClamp(Function):
    """
    Straight-through clamp function.
    """

    # pylint: disable=abstract-method, redefined-builtin, arguments-differ

    @staticmethod
    def forward(
        ctx: FunctionCtx,
        x_input: Tensor,
        upper_thresh: Tensor,
        lower_thresh: Tensor,
        input_range: Tensor,
    ) -> Tensor:
        # Mark the entries flagged as clipped, except those sitting exactly on the
        # range boundary, so only truly clipped positions get their gradient zeroed.
        clamped_mask = upper_thresh | lower_thresh
        exactly_the_same = x_input.abs() == input_range
        ctx.save_for_backward(clamped_mask & ~exactly_the_same)
        return clamp(x_input, min=-input_range, max=input_range)

    @staticmethod
    def backward(  # type: ignore[override]
        ctx: FunctionCtx, d_output: Tensor
    ) -> Tuple[Tensor, None, None, None, None]:
        (clamped_mask,) = ctx.saved_tensors  # type: ignore[attr-defined]

        # Out-of-place variant: clone the incoming gradient and zero the clipped
        # positions. Works under torch.compile, but doesn't this blow up memory?
        d_output_zeroed_out = d_output.clone()
        d_output_zeroed_out[clamped_mask] = 0.0

        # In-place variant: works in eager mode, but not under torch.compile.
        # d_output[clamped_mask] = 0.0

        return d_output_zeroed_out, None, None, None, None
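For reference, a minimal sketch of how the function can be exercised under torch.compile. The straight_through_clamp wrapper, the threshold construction, and the shapes below are illustrative stand-ins, not the actual surrounding module:

import torch
from torch import Tensor


def straight_through_clamp(x_input: Tensor, input_range: Tensor) -> Tensor:
    # Illustrative thresholds: flag entries outside [-input_range, input_range].
    upper_thresh = x_input > input_range
    lower_thresh = x_input < -input_range
    return StraightThroughClamp.apply(x_input, upper_thresh, lower_thresh, input_range)


compiled_clamp = torch.compile(straight_through_clamp)

x = torch.randn(8, 16, requires_grad=True)
input_range = torch.tensor(1.0)

out = compiled_clamp(x, input_range)
out.sum().backward()  # gradient is zeroed where the input was clipped
print(x.grad)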
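For comparison, the same zeroing can also be written without any indexed assignment, for example as a masked multiply. This is only a sketch of the backward body with stand-in tensors; it still allocates a fresh tensor the size of d_output, so it is not obviously cheaper than the clone, which is exactly the open question:

import torch

# Stand-ins for the tensors seen inside backward().
d_output = torch.randn(8, 16)
clamped_mask = torch.rand(8, 16) > 0.8

# Out-of-place zeroing: the bool mask is promoted to d_output's dtype, so clipped
# positions are multiplied by 0 and the rest by 1. Allocates a new tensor.
d_output_zeroed_out = d_output * ~clamped_mask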