During my backward hook, I'd like to know the values of `retain_graph` and `create_graph` that were passed to `backward()`. For `create_graph`, the value can be determined with `torch.is_grad_enabled()`, since grad mode is enabled during the backward pass exactly when `create_graph=True`. Is there a similar function for `retain_graph`? If not, is there a workaround?
Here is a code example illustrating the point at which I'd like to determine the values:
```python
from torch import nn, is_grad_enabled, rand, randint

def _hook(module, grad_input, grad_output):
    # ??? is where I'd like to print the actual value of retain_graph
    print(f"{module.__class__.__name__}, retain_graph=???, create_graph={is_grad_enabled()}")

def forward_backward(retain_graph=False, create_graph=False):
    print(f"\ntest: retain_graph={retain_graph}, create_graph={create_graph}")
    loss = loss_fn(model(rand(8, 5)), randint(4, (8,)))
    loss.backward(retain_graph=retain_graph, create_graph=create_graph)

model = nn.Linear(5, 4)
loss_fn = nn.CrossEntropyLoss()
model.register_full_backward_hook(_hook)

forward_backward()
forward_backward(retain_graph=True)
forward_backward(retain_graph=True, create_graph=True)
forward_backward(retain_graph=False, create_graph=True)
```
which prints the following output:
```
test: retain_graph=False, create_graph=False
Linear, retain_graph=???, create_graph=False

test: retain_graph=True, create_graph=False
Linear, retain_graph=???, create_graph=False

test: retain_graph=True, create_graph=True
Linear, retain_graph=???, create_graph=True

test: retain_graph=False, create_graph=True
Linear, retain_graph=???, create_graph=True
```
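The only workaround I can think of is to record the flag myself at the call site so the hook can read it during the backward pass. Below is a minimal sketch of that idea; the `_backward_ctx` dict and `backward_with_ctx` helper are my own names, not a torch API:

```python
from torch import nn, is_grad_enabled, rand, randint

# Hypothetical helper state, not part of torch: the call site records
# retain_graph here before calling backward(), and the hook reads it.
_backward_ctx = {"retain_graph": None}

def _hook(module, grad_input, grad_output):
    print(f"{module.__class__.__name__}, "
          f"retain_graph={_backward_ctx['retain_graph']}, "
          f"create_graph={is_grad_enabled()}")

def backward_with_ctx(loss, retain_graph=False, create_graph=False):
    # Stash the flag before backward so hooks fired during the pass see it.
    _backward_ctx["retain_graph"] = retain_graph
    try:
        loss.backward(retain_graph=retain_graph, create_graph=create_graph)
    finally:
        _backward_ctx["retain_graph"] = None

model = nn.Linear(5, 4)
loss_fn = nn.CrossEntropyLoss()
model.register_full_backward_hook(_hook)

loss = loss_fn(model(rand(8, 5)), randint(4, (8,)))
backward_with_ctx(loss, retain_graph=True)
```

This only covers calls to `backward()` that go through my wrapper, though, so I'd still prefer a real introspection function if one exists.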