When/why might I get a memory leak when storing tensors in ctx?

For example, via: How to transition to functions not being allowed to have member variables - #2 by albanD

There is this example code:

class YourFn(Function):
    @staticmethod
    def forward(ctx, arg1, arg2, my_state):
        # This assumes that my_state is NOT a Tensor.
        # If it is, you have to use ctx.save_for_backward()
        # or you will see a memory leak.
        ctx.my_state = my_state
        # compute the output
        return output

    @staticmethod
    def backward(ctx, grad_output):
        my_state = ctx.my_state
        # compute grad1, grad2
        return grad1, grad2, None

fn_state = {}
output = YourFn.apply(arg1, arg2, fn_state)

It says that I will see a memory leak when I put a tensor into fn_state (i.e. when I don’t use save_for_backward).

Why is that the case? How can I avoid it when I must store some tensor in fn_state?


To expand on my use case, via CTCLoss gradient is incorrect · Issue #52241 · pytorch/pytorch · GitHub:

import torch


def torch_ctc_fixed_grad(
    log_probs: torch.Tensor,
    targets: torch.Tensor,
    input_lengths: torch.Tensor,
    target_lengths: torch.Tensor,
    *args,
    **kwargs,
) -> torch.Tensor:
    log_probs, loss_scale_buffer = _FixCTCGradFunc.apply(log_probs, input_lengths)
    loss = torch.nn.functional.ctc_loss(log_probs, targets, input_lengths, target_lengths, *args, **kwargs)
    loss = _StoreGradScaleFunc.apply(loss, loss_scale_buffer)
    return loss


class _FixCTCGradFunc(torch.autograd.Function):
    @staticmethod
    def forward(ctx, log_probs, input_lengths):
        loss_scale_buffer = {}
        ctx.loss_scale_buffer = loss_scale_buffer
        ctx.save_for_backward(log_probs, input_lengths)
        return log_probs, loss_scale_buffer

    @staticmethod
    def backward(ctx, grad_output, _grad_scale):
        loss_scale_buffer = ctx.loss_scale_buffer
        (log_probs, input_lengths) = ctx.saved_tensors
        assert isinstance(loss_scale_buffer, dict) and set(loss_scale_buffer.keys()) == {"scale"}
        # Pop so that we avoid any potential memory leaks.
        loss_scale_buffer: torch.Tensor = loss_scale_buffer.pop("scale")

        # The ctc_loss calculates (exp(log_probs) - y) * scale,
        # where y are the soft targets,
        # and where we control scale=1 via _StoreGradScaleFunc.
        # We want to return -y * loss_scale_buffer instead.
        # Thus, subtract the exp(log_probs) from the grad_output.
        grad_input = grad_output - log_probs.exp()  # [T, N, C]
        if loss_scale_buffer.ndim == 1:
            grad_input.multiply_(loss_scale_buffer[None, :, None])
        else:
            grad_input.multiply_(loss_scale_buffer)
        input_lengths = input_lengths.to(grad_input.device)
        max_time = grad_input.shape[0]
        mask = torch.arange(max_time, device=input_lengths.device)[:, None] < input_lengths[None, :]  # [T, N]
        grad_input = torch.where(mask[:, :, None], grad_input, torch.zeros_like(grad_input))

        return grad_input, None


class _StoreGradScaleFunc(torch.autograd.Function):
    @staticmethod
    def forward(ctx, loss, loss_scale_buffer):
        ctx.loss_scale_buffer = loss_scale_buffer
        return loss.clone()

    @staticmethod
    def backward(ctx, grad_output):
        loss_scale_buffer = ctx.loss_scale_buffer
        assert not loss_scale_buffer
        loss_scale_buffer["scale"] = grad_output
        return torch.ones_like(grad_output), None
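
For context, a typical call would look like this (shapes, dtypes and the reduction here are just an illustration, not my real setup):

T, N, C, S = 50, 8, 20, 12  # time, batch, classes, max target length
log_probs = torch.randn(T, N, C).log_softmax(-1).requires_grad_()
targets = torch.randint(1, C, (N, S))  # 0 is the blank label
input_lengths = torch.full((N,), T, dtype=torch.long)
target_lengths = torch.randint(5, S + 1, (N,), dtype=torch.long)

loss = torch_ctc_fixed_grad(log_probs, targets, input_lengths, target_lengths, reduction="sum")
loss.backward()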

Is this safe? Would I get memory leaks here or not, and how can I avoid them? Or, if this is unsafe in other ways, how can I fix it?

If you save a forward output of the custom autograd Function via ctx, that output references its own grad_fn, which in turn saves/references the output, creating a reference cycle. ctx.save_for_backward prevents this by automatically stashing and restoring the grad_fn, which avoids the cycle.
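
To make the cycle concrete, here is a minimal sketch (the Function and its names are made up for illustration):

import torch

class Leaky(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        out = x * 2
        # After apply() returns, out.grad_fn is this Function's backward node,
        # and that node owns ctx. Stashing out on ctx therefore closes a cycle:
        #   out -> grad_fn -> ctx -> out
        ctx.out = out  # leaks; ctx.save_for_backward(out) would avoid the cycle
        return out

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output * 2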

You can work around this by detaching those tensors before stashing them on ctx, but this means that higher-order gradients will not work.
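
A sketch of that workaround, using the same made-up example:

class NoCycle(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        out = x * 2
        # The detached alias shares storage but carries no grad_fn, so no
        # reference cycle is formed. The price: double backward cannot
        # differentiate through ctx.out anymore.
        ctx.out = out.detach()
        return out

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output * 2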

Separately, the memory benefit of storing large tensors via ctx.save_for_backward is that they are freed as soon as the backward computation finishes, whereas tensors saved directly on ctx are kept alive for as long as the autograd graph is alive.
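
A rough way to see that lifetime difference (a sketch; the tensor here is neither an input nor an output, it only illustrates when memory gets released):

import torch

class KeepsOnCtx(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.big = torch.zeros(1024, 1024)  # stays alive as long as the graph does
        return x * 2

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output * 2

x = torch.randn(4, requires_grad=True)
out = KeepsOnCtx.apply(x)
out.sum().backward()
# ctx.big is still allocated here: `out` keeps its grad_fn alive, and the
# grad_fn keeps ctx alive. A tensor passed to ctx.save_for_backward would have
# been released right after backward (without retain_graph=True).
del out  # dropping the last reference to the graph finally frees ctx.big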