Infer value of retain_graph and create_graph

During my backward hook, I’d like to know the values of retain_graph and create_graph.
For create_graph, the value can be determined with torch.is_grad_enabled().
Is there a similar function for retain_graph?
If not, is there a workaround?

Here is a code illustration showing the point at which I'd like to determine the value.

from torch import nn, is_grad_enabled, rand, randint


def _hook(module, grad_input, grad_output):
    print(f"{module.__class__.__name__}, retain_grad=???, create_graph={is_grad_enabled()}")


def forward_backward(retain_graph=False, create_graph=False):
    print(f"\ntest: retain_grad={retain_graph}, create_graph={create_graph}")
    loss = loss_fn(model(rand(8, 5)), randint(4, (8,)))
    loss.backward(retain_graph=retain_graph, create_graph=create_graph)


model = nn.Linear(5, 4)
loss_fn = nn.CrossEntropyLoss()
model.register_full_backward_hook(_hook)

forward_backward()
forward_backward(retain_graph=True)
forward_backward(retain_graph=True, create_graph=True)
forward_backward(retain_graph=False, create_graph=True)

with output:

test: retain_grad=False, create_graph=False
Linear, retain_grad=???, create_graph=False

test: retain_grad=True, create_graph=False
Linear, retain_grad=???, create_graph=False

test: retain_grad=True, create_graph=True
Linear, retain_grad=???, create_graph=True

test: retain_grad=False, create_graph=True
Linear, retain_grad=???, create_graph=True

The easiest is probably to just ask grad / backward for the parameter it got (after fixing the retain_grad vs. retain_graph typo):

import inspect

from torch import nn, is_grad_enabled, rand, randint


def _hook(module, grad_input, grad_output):
    # grab the outermost torch.autograd frame on the call stack
    fr = [f[0] for f in inspect.stack()
          if '/torch/autograd/' in f.filename][-1]  # could be made more precise to check for backward/grad
    retain_graph = inspect.getargvalues(fr).locals['retain_graph']
    print(f"{module.__class__.__name__}, {retain_graph=}, create_graph={is_grad_enabled()}")


def forward_backward(retain_graph=False, create_graph=False):
    print(f"\ntest: retain_graph={retain_graph}, create_graph={create_graph}")
    loss = loss_fn(model(rand(8, 5)), randint(4, (8,)))
    loss.backward(retain_graph=retain_graph, create_graph=create_graph)


model = nn.Linear(5, 4)
loss_fn = nn.CrossEntropyLoss()
model.register_full_backward_hook(_hook)

forward_backward()
forward_backward(retain_graph=True)
forward_backward(retain_graph=True, create_graph=True)
forward_backward(retain_graph=False, create_graph=True)
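
As the inline comment suggests, the stack lookup can be made a bit more targeted by also checking the frame's function name. A minimal sketch (the helper name is mine, and this pokes at torch.autograd internals, so it may break across PyTorch versions):

import inspect


def _find_retain_graph():
    # Walk the call stack from outermost to innermost and return the
    # retain_graph argument of the torch.autograd backward()/grad() frame.
    for f in reversed(inspect.stack()):
        if '/torch/autograd/' in f.filename and f.function in ('backward', 'grad'):
            return inspect.getargvalues(f.frame).locals.get('retain_graph')
    return None  # hook was not reached via torch.autograd.backward()/grad()

This could replace the two stack-inspection lines in _hook above.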

You did not ask whether what you want to do is a good idea. That is just as well, because in all likelihood it is not a good idea. You would be building something quite brittle here, based on casual observations of a given PyTorch version in very simple use cases.

Best regards

Thomas