For example, via How to transition to functions not being allowed to have member variables - #2 by albanD, there is this example code:
```python
class YourFn(Function):
    @staticmethod
    def forward(ctx, arg1, arg2, my_state):
        # This assumes that my_state is NOT a Tensor.
        # If it is, you have to use ctx.save_for_backward()
        # or you will see a memory leak.
        ctx.my_state = my_state
        # compute the output
        return output

    @staticmethod
    def backward(ctx, grad_output):
        my_state = ctx.my_state
        # compute grad1, grad2
        return grad1, grad2, None


fn_state = {}
output = YourFn.apply(arg1, arg2, fn_state)
```
The comment says I will see a memory leak when I put some tensor into fn_state (i.e. when I don't use save_for_backward). Why is that the case? And how can I avoid the leak when I must store some tensor in fn_state?
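To make sure I understand the recommendation: if the state were a plain tensor, I assume the intended pattern would look roughly like this (my own sketch with a placeholder computation, not code from the linked post):

```python
import torch
from torch.autograd import Function


class YourFnSaved(Function):
    """Hypothetical variant of YourFn where my_state IS a tensor."""

    @staticmethod
    def forward(ctx, arg1, arg2, my_state):
        # Tensors are registered via save_for_backward instead of being
        # stashed directly on ctx.
        ctx.save_for_backward(my_state)
        return (arg1 + arg2) * my_state  # placeholder computation

    @staticmethod
    def backward(ctx, grad_output):
        (my_state,) = ctx.saved_tensors
        grad1 = grad_output * my_state
        grad2 = grad_output * my_state
        return grad1, grad2, None


arg1 = torch.randn(4, requires_grad=True)
arg2 = torch.randn(4, requires_grad=True)
my_state = torch.randn(4)
out = YourFnSaved.apply(arg1, arg2, my_state)
out.sum().backward()
```

But in my case (below), the tensor only gets put into the dict during backward, so I don't see how save_for_backward would apply.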
To expand, here is my use case, via CTCLoss gradient is incorrect · Issue #52241 · pytorch/pytorch · GitHub:
```python
def torch_ctc_fixed_grad(
    log_probs: torch.Tensor,
    targets: torch.Tensor,
    input_lengths: torch.Tensor,
    target_lengths: torch.Tensor,
    *args,
    **kwargs,
) -> torch.Tensor:
    log_probs, loss_scale_buffer = _FixCTCGradFunc.apply(log_probs, input_lengths)
    loss = torch.nn.functional.ctc_loss(log_probs, targets, input_lengths, target_lengths, *args, **kwargs)
    loss = _StoreGradScaleFunc.apply(loss, loss_scale_buffer)
    return loss


class _FixCTCGradFunc(torch.autograd.Function):
    @staticmethod
    def forward(ctx, log_probs, input_lengths):
        loss_scale_buffer = {}
        ctx.loss_scale_buffer = loss_scale_buffer
        ctx.save_for_backward(log_probs, input_lengths)
        return log_probs, loss_scale_buffer

    @staticmethod
    def backward(ctx, grad_output, _grad_scale):
        loss_scale_buffer = ctx.loss_scale_buffer
        (log_probs, input_lengths) = ctx.saved_tensors
        assert isinstance(loss_scale_buffer, dict) and set(loss_scale_buffer.keys()) == {"scale"}
        # Pop so that we avoid any potential memory leaks.
        loss_scale_buffer: torch.Tensor = loss_scale_buffer.pop("scale")

        # The ctc_loss calculates (exp(log_probs) - y) * scale,
        # where y are the soft targets,
        # and where we control scale=1 via _StoreGradScaleFunc.
        # We want to return -y * loss_scale_buffer instead.
        # Thus, subtract the exp(log_probs) from the grad_output.
        grad_input = grad_output - log_probs.exp()  # [T, N, C]
        if loss_scale_buffer.ndim == 1:
            grad_input.multiply_(loss_scale_buffer[None, :, None])
        else:
            grad_input.multiply_(loss_scale_buffer)

        input_lengths = input_lengths.to(grad_input.device)
        max_time = grad_input.shape[0]
        mask = torch.arange(max_time, device=input_lengths.device)[:, None] < input_lengths[None, :]  # [T, N]
        grad_input = torch.where(mask[:, :, None], grad_input, torch.zeros_like(grad_input))
        return grad_input, None


class _StoreGradScaleFunc(torch.autograd.Function):
    @staticmethod
    def forward(ctx, loss, loss_scale_buffer):
        ctx.loss_scale_buffer = loss_scale_buffer
        return loss.clone()

    @staticmethod
    def backward(ctx, grad_output):
        loss_scale_buffer = ctx.loss_scale_buffer
        assert not loss_scale_buffer
        loss_scale_buffer["scale"] = grad_output
        return torch.ones_like(grad_output), None
```
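For completeness, this is roughly how I call it, as a drop-in replacement for torch.nn.functional.ctc_loss. The shapes and keyword arguments here are just an arbitrary example:

```python
import torch

T, N, C = 50, 8, 20  # time steps, batch size, classes (arbitrary example sizes)
log_probs = torch.randn(T, N, C).log_softmax(-1).requires_grad_()
targets = torch.randint(1, C, (N, 30), dtype=torch.long)
input_lengths = torch.full((N,), T, dtype=torch.long)
target_lengths = torch.randint(10, 30, (N,), dtype=torch.long)

loss = torch_ctc_fixed_grad(
    log_probs, targets, input_lengths, target_lengths,
    blank=0, reduction="none", zero_infinity=True,
)
loss.sum().backward()
print(log_probs.grad.shape)  # torch.Size([50, 8, 20]), i.e. [T, N, C]
```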
Is this safe? Would I get memory leaks here or not? How can I avoid them? Or, if this is unsafe in other ways, how can I fix it?