Help making modifications to the computation graph before backprop

I need help modifying/rewriting the computation graph.

I have a model f with the following property:

$$f(d, x_1) = x_2 \implies f(d, x_2) = x_2$$

This holds by construction, given how I compute a normalized p in the code below.
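A short derivation of why the property holds, using the definitions from the code below: d has unit norm (it is normalized before use), m = x × d, p = (d, m), y = net(p1), and x2 = d × m1 + y d. The first step is the vector triple product identity a × (b × c) = b (a · c) − c (a · b):

$$
\begin{aligned}
d \times m_1 &= d \times (x_1 \times d) = x_1 (d \cdot d) - d\,(d \cdot x_1) = x_1 - (d \cdot x_1)\, d \\
x_2 &= d \times m_1 + y\, d = x_1 + (y - d \cdot x_1)\, d \\
m_2 &= x_2 \times d = x_1 \times d + (y - d \cdot x_1)(d \times d) = m_1
\end{aligned}
$$

So x2 differs from x1 only along d, p2 = (d, m2) = p1, the network sees the same input and returns the same y, and therefore f(d, x2) = d × m2 + y d = x2.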

I need to compute the gradient of f w.r.t. d given that x2 is the input, but at training time I only have x1 available for supervision.

$$\nabla_d f(d, x_2)$$

The easiest way to achieve this is to simply compute the forward pass twice:

$$\nabla_d f(d, f(d, x_1))$$
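A minimal sketch of this two-pass baseline, assuming a hypothetical helper `f(d, x, net)` that wraps the forward pass from the script below and returns both the predicted x and the network output y. Here x2 is detached before the second pass so that it acts as a fixed input (as in the p2 construction in the script), and the differentiated quantity is y, matching the objective at the end of the script:

```python
import torch
from torch import nn

torch.manual_seed(0)


def f(d, x, net):
    # Hypothetical helper: one forward pass of the model from the script below.
    m = torch.cross(x, d, dim=-1)              # m = x × d
    p = torch.cat((d, m), dim=-1)              # p = (d, m)
    y = net(p)
    x_out = torch.cross(d, m, dim=-1) + d * y  # same as d.cross(m1) + d * y
    return x_out, y


d = torch.tensor([[1.0, 2.0, 3.0]])
d = (d / d.norm(dim=-1, keepdim=True)).requires_grad_()
x1 = torch.tensor([[4.0, 5.0, 6.0]])
net = nn.Sequential(
    nn.Linear(6, 10), nn.Linear(10, 10), nn.Linear(10, 10), nn.Linear(10, 1)
)

# Pass 1: predict x2 from the supervised input x1.
x2, _ = f(d, x1, net)

# Pass 2: re-run the full model with x2 as a fixed (detached) input.
x3, y2 = f(d, x2.detach(), net)

# Gradient of y w.r.t. d as if x2 had been the input.
y2_grad = torch.autograd.grad(
    y2, [d], grad_outputs=torch.ones_like(y2), create_graph=True
)[0]
```

The second call repeats the whole network forward pass even though its input p is numerically identical to the one from the first call; x3 matches x2, which is the fixed-point property above.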

This is, however, expensive and largely wasted work, since most of the values computed in the two forward passes are identical.
My goal is to take the computation graph of a single forward pass based on x1 and replace the part of the graph unique to x1 with a partial forward pass based on x2.
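Because y depends on x only through p, and p2 equals p1 numerically, the network's vector-Jacobian product ∂y/∂p is the same at both points; only ∂p/∂d differs between "x1 as input" and "x2 as input". The substituted graph should therefore yield (∂y/∂p evaluated at p1) · ∂p2/∂d. A minimal sketch of that target gradient, assuming the same toy setup as the script below (`make_p` is a hypothetical helper), computes it with two `torch.autograd.grad` calls and no second network pass, then checks it against the full second pass:

```python
import torch
from torch import nn

torch.manual_seed(0)

# Same toy setup as the script below.
d = torch.tensor([[1.0, 2.0, 3.0]])
d = (d / d.norm(dim=-1, keepdim=True)).requires_grad_()
x1 = torch.tensor([[4.0, 5.0, 6.0]])
net = nn.Sequential(
    nn.Linear(6, 10), nn.Linear(10, 10), nn.Linear(10, 10), nn.Linear(10, 1)
)


def make_p(d, x):
    # Hypothetical helper: the cheap, x-dependent part of the graph, p = (d, x × d).
    return torch.cat((d, torch.cross(x, d, dim=-1)), dim=-1)


# Single forward pass through the network, driven by x1.
m1 = torch.cross(x1, d, dim=-1)
p1 = torch.cat((d, m1), dim=-1)
y = net(p1)
x2 = torch.cross(d, m1, dim=-1) + d * y

# Stage 1: vector-Jacobian product of the network w.r.t. its input p1.
# p1 is a non-leaf tensor, but autograd.grad accepts it as an input.
g_p = torch.autograd.grad(
    y, p1, grad_outputs=torch.ones_like(y), create_graph=True
)[0]

# Stage 2: rebuild only the x-dependent part from the detached x2 and
# backpropagate g_p through it to d.
p2 = make_p(d, x2.detach())
y_grad = torch.autograd.grad(p2, d, grad_outputs=g_p, create_graph=True)[0]

# Reference: full second forward pass through the network from p2.
y_ref = net(p2)
y_grad_ref = torch.autograd.grad(
    y_ref, d, grad_outputs=torch.ones_like(y_ref), create_graph=True
)[0]
```

This splits the backward pass at p with the chain rule rather than rewiring grad_fn, and it relies on y depending on x only through p.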

My problem is that torch.Tensor.grad_fn, like grad_fn.next_functions, is read-only.

My exploration thus far:

import torch
from torch import nn

# Input

d = torch.tensor([1, 2, 3]).float()[None, ...]
d = d / d.norm(dim=-1, keepdim=True)
x1 = torch.tensor([4, 5, 6]).float()[None, ...]


# Forward

d.requires_grad = True
m1 = torch.cross(x1, d, dim=-1)
p1 = torch.cat((d, m1), dim=-1)

p1 = p1 + 0 # a no-op "shim" to enable some of the attempted grad_fn replacement tricks below

net = nn.Sequential(
    nn.Linear( 6, 10),
    nn.Linear(10, 10),
    nn.Linear(10, 10),
    nn.Linear(10,  1),
)

y = net(p1)
x2 = d.cross(m1, dim=-1) + d * y


# Loss

# Compute a new p, based on predicted x2
m2 = torch.cross(x2.detach(), d, dim=-1)
p2 = torch.cat((d, m2), dim=-1)

# p1 and p2 are equal, despite x1 and x2 not being equal
assert not torch.allclose(x1, x2),  f"\n{x1}\n{x2}"
assert     torch.allclose(p1, p2),  f"\n{p1}\n{p2}"

# Objective:
# Compute the gradient of y w.r.t. d as if x2, rather than x1, had been the input

print(p1.grad_fn) # p1 + 0
# <AddBackward0 object at 0x7f93737d0280>
print(p2.grad_fn) # p2
# <CatBackward0 object at 0x7f93737d0280>
print(p1.grad_fn.next_functions)
# ((<CatBackward0 object at 0x7f93737d0fd0>, 0), (None, 0))
print((p2 + 0).grad_fn.next_functions)
# ((<CatBackward0 object at 0x7f93737d0fd0>, 0), (None, 0))

# GOAL:
# I want to replace the CatBackward0 in p1.grad_fn.next_functions with p2.grad_fn

# === Attempt #1:
try:
    p1.grad_fn.next_functions[0][0] = p2.grad_fn
    # PROBLEM: this ^ operation is illegal:
except Exception as e:
    print(f"ERROR: {e.__class__.__name__}: {e}")
    # TypeError: 'tuple' object does not support item assignment
# The tuple is just a copy, and not a view anyway:
assert p1.grad_fn.next_functions is not p1.grad_fn.next_functions

# === Attempt #2:
try:
    p1.grad_fn.next_functions = (p2 + 0).grad_fn.next_functions
    # PROBLEM: this ^ operation is also illegal:
except Exception as e:
    print(f"ERROR: {e.__class__.__name__}: {e}")
    # AttributeError: attribute 'next_functions' of 'AddBackward0' objects is not writable

# === Attempt #3:
try:
    p1.grad_fn = p2.grad_fn
    # PROBLEM: this ^ operation is also illegal:
except Exception as e:
    print(f"ERROR: {e.__class__.__name__}: {e}")
    # AttributeError: attribute 'grad_fn' of 'torch._C._TensorBase' objects is not writable


# Assuming the objective above has been achieved:

# Compute y grad w.r.t. d
y_grad = torch.autograd.grad(y, [d], grad_outputs=torch.ones_like(y), create_graph=True)[0].norm()
loss = y_grad.abs().sum()

How can I modify grad_fn?