In-place modification in backward not allowed with torch.compile

Modifying the gradient in-place does not work with torch.compile. If I instead clone the incoming gradient, zero out the elements I want to zero out, and then pass that on, it works. But doesn't that use too much memory?

Is there a way to achieve this without blowing up memory while still being able to use torch.compile?

from typing import Tuple

from torch import Tensor, clamp
from torch.autograd import Function
from torch.autograd.function import FunctionCtx


class StraightThroughClamp(Function):
    """
    Straight-through clamp function.
    """

    # pylint: disable=abstract-method, redefined-builtin, arguments-differ

    @staticmethod
    def forward(
        ctx: FunctionCtx,
        x_input: Tensor,
        upper_thresh: Tensor,
        lower_thresh: Tensor,
        input_range: Tensor,
    ) -> Tensor:
        # remember which elements were clamped, excluding those that sit
        # exactly on the boundary of the input range
        clamped_mask = upper_thresh | lower_thresh
        exactly_the_same = x_input.abs() == input_range
        ctx.save_for_backward(clamped_mask & ~exactly_the_same)
        return clamp(x_input, min=-input_range, max=input_range)

    @staticmethod
    def backward(  # type: ignore[override]
        ctx: FunctionCtx, d_output: Tensor
    ) -> Tuple[Tensor, None, None, None]:
        (clamped_mask,) = ctx.saved_tensors  # type: ignore[attr-defined]

        # Option A: clone, then zero out -- works under torch.compile,
        # but doesn't this blow up memory?
        d_output_zeroed_out = d_output.clone()
        d_output_zeroed_out[clamped_mask] = 0.0

        # Option B: zero out in place -- works in eager, but not under torch.compile
        # d_output[clamped_mask] = 0.0
        # return d_output, None, None, None

        return d_output_zeroed_out, None, None, None
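For context, here is a minimal driver I would use to exercise this under torch.compile. It is a sketch: the mask construction (comparing against ±input_range) is an assumption, since the original code does not show how upper_thresh / lower_thresh are built.

import torch

def straight_through_clamp(x: torch.Tensor, input_range: torch.Tensor) -> torch.Tensor:
    # hypothetical threshold construction -- assumed, not from the class above
    upper_thresh = x > input_range
    lower_thresh = x < -input_range
    return StraightThroughClamp.apply(x, upper_thresh, lower_thresh, input_range)

compiled_clamp = torch.compile(straight_through_clamp)

x = torch.randn(8, requires_grad=True)
input_range = torch.tensor(1.0)
y = compiled_clamp(x, input_range)
y.sum().backward()
print(x.grad)  # zero where the input was clamped, pass-through elsewhere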

Are you using the latest version of PyTorch? We added support for mutating tangents in compile in pytorch/pytorch#141131 ("support input mutations on tangents in compile", by bdhirsh): https://github.com/pytorch/pytorch/pull/141131
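On a version that includes that PR, the in-place variant of the backward should be usable directly under torch.compile. A minimal sketch, reusing the same forward and saved mask as in the question (not a tested drop-in):

    @staticmethod
    def backward(  # type: ignore[override]
        ctx: FunctionCtx, d_output: Tensor
    ) -> Tuple[Tensor, None, None, None]:
        (clamped_mask,) = ctx.saved_tensors  # type: ignore[attr-defined]
        # zero the clamped positions directly on the incoming gradient;
        # no clone, so no extra memory, but this relies on compile
        # supporting tangent mutation (PR #141131 above)
        d_output[clamped_mask] = 0.0
        return d_output, None, None, None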

I would say that in-place modification of gradients is actually unsafe in eager today: if, for example, you did an add afterwards, the gradient you receive in backward can be aliased to another tensor, so mutating it in place can affect gradients elsewhere in the graph. There are plans to add safeguards so this can be done safely, however.
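To illustrate the concern with a hypothetical sketch (not from the original thread): when the custom op's output feeds an add, the add's backward forwards the incoming gradient unchanged, so the tensor arriving in the custom backward can be the very same object the engine routes to another branch.

import torch
from torch.autograd import Function


class ZeroFirst(Function):
    """Hypothetical op whose backward edits the incoming gradient in place."""

    @staticmethod
    def forward(ctx, t):
        return t.clone()

    @staticmethod
    def backward(ctx, grad):
        grad[0] = 0.0  # in-place edit of the incoming gradient
        return grad


x = torch.randn(3, requires_grad=True)
y = torch.randn(3, requires_grad=True)

# the add's backward typically forwards the same gradient tensor to both the
# custom node and y's gradient accumulation
out = ZeroFirst.apply(x) + y
(2.0 * out).sum().backward()

# whether y.grad observes the in-place edit depends on engine internals and
# execution order, which is exactly what makes the pattern unsafe
print(x.grad, y.grad)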