Gradient penalty loss with modified weights

Hey, I am currently running into an issue with my regularization loss (which is computed and backpropagated on its own, without first being combined with a normal loss): it gives me the error ‘trying to backward through the graph a second…’.

I modify the weights using the outputs of another layer, and that modification is what prevents me from passing retain_graph=False to torch.autograd.grad().

Here is a simple code example which fails:

import torch
import torch.nn as nn

w = nn.Parameter(torch.zeros(10, 10))
dense = nn.Linear(10, 10)

x = torch.rand(1, 10).requires_grad_(True)
# Scale the transposed weight by the mean output of another layer before the matmul
y = x.matmul(w.t() * dense(x).mean(0)).mean()

grad = torch.autograd.grad(
    outputs=y,
    inputs=x,
    retain_graph=False,
    create_graph=True,
    only_inputs=True
)[0]

grad.mean().backward()  # raises the error quoted above

And here is a version that works, but without the weight modification:

w = nn.Parameter(torch.zeros(10, 10))

x = torch.rand(1, 10).requires_grad_(True)
# Plain matmul with the unmodified weight
y = x.matmul(w.t()).mean()

grad = torch.autograd.grad(
    outputs=y,
    inputs=x,
    retain_graph=False,
    create_graph=True,
    only_inputs=True
)[0]

grad.mean().backward()  # runs without error

Is there any way for me to run this without setting retain_graph=True? Or does retaining the graph make no real difference in memory/performance for my specific example?
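
For reference, this is the retain_graph=True variant I am trying to avoid; a minimal sketch using the same toy setup as above:

import torch
import torch.nn as nn

w = nn.Parameter(torch.zeros(10, 10))
dense = nn.Linear(10, 10)

x = torch.rand(1, 10).requires_grad_(True)
y = x.matmul(w.t() * dense(x).mean(0)).mean()

grad = torch.autograd.grad(
    outputs=y,
    inputs=x,
    retain_graph=True,   # keep the forward graph's saved tensors alive
    create_graph=True,
    only_inputs=True
)[0]

grad.mean().backward()  # succeeds, since the graph needed for the double backward was not freed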