# Help making modifications to the computation graph before backprop

I need help modifying/rewriting the computation graph.

I have a model f with the following property:

$$
f(d, x_1) = x_2 \implies f(d, x_2) = x_2
$$

This holds by construction, because of how I compute the normalized p in the code below.
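Why the property holds, assuming d is unit-length (it is normalized in the code) and y is the scalar network output. With $m_1 = x_1 \times d$, the vector triple product shows that $x_2$ only moves $x_1$ along $d$. That leaves the moment, and therefore $p = (d, m)$, unchanged:

$$
\begin{aligned}
d \times m_1 &= d \times (x_1 \times d) = x_1 (d \cdot d) - d\,(d \cdot x_1) = x_1 - (d \cdot x_1)\, d \\
x_2 &= d \times m_1 + y\, d = x_1 + (y - d \cdot x_1)\, d \\
m_2 &= x_2 \times d = x_1 \times d + (y - d \cdot x_1)\,(d \times d) = m_1
\end{aligned}
$$

Since $p$ is identical for $x_1$ and $x_2$, the network output is too, so $f(d, x_2) = d \times m_2 + y\,d = x_2$.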

I need to compute the gradient of f w.r.t. d given that x2 is the input, but at training time I only have x1 available for supervision.

$$
\nabla_d f(d, x_2)
$$
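Why the x1-based graph cannot simply be reused for this gradient: p has the same *value* for $x_1$ and $x_2$, but not the same Jacobian with respect to $d$. This sketch treats $x$ as a constant, as `x2.detach()` does in the code. Write $[x]_\times$ for the skew-symmetric matrix with $[x]_\times v = x \times v$. For unit $d$, the code's `x2 = d × m1 + d * y` gives $x_2 = x_1 + t\,d$ with $t = y - d \cdot x_1$. Then:

$$
p(d, x) = \begin{pmatrix} d \\ x \times d \end{pmatrix}, \qquad
\frac{\partial p}{\partial d} = \begin{pmatrix} I_3 \\ [x]_\times \end{pmatrix}, \qquad
[x_2]_\times = [x_1]_\times + t\,[d]_\times
$$

So backpropagating through the x1 branch gives a different $\nabla_d$ whenever $t \neq 0$, even though every forward value downstream of $p$ is identical.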

The easiest way to achieve this is to simply run the forward pass twice:

$$
\nabla_d f(d, f(d, x_1))
$$
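Below is a minimal sketch of this double-forward baseline. It assumes the same setup as the code further down. `make_p` is a hypothetical helper that builds p = (d, x × d), and `torch.manual_seed` only makes the sketch reproducible. The result is correct, but the network runs twice on numerically identical inputs:

```python
import torch
from torch import nn

torch.manual_seed(0)  # only for reproducibility of the sketch


def make_p(x, d):
    # Hypothetical helper: p = (d, m) with moment m = x × d, as in the question
    m = torch.cross(x, d, dim=-1)
    return torch.cat((d, m), dim=-1)


d_raw = torch.tensor([[1.0, 2.0, 3.0]], requires_grad=True)
d = d_raw / d_raw.norm(dim=-1, keepdim=True)
x1 = torch.tensor([[4.0, 5.0, 6.0]])

net = nn.Sequential(
    nn.Linear(6, 10),
    nn.Linear(10, 10),
    nn.Linear(10, 10),
    nn.Linear(10, 1),
)

# Forward pass 1: driven by x1, only used to predict x2
p1 = make_p(x1, d)
y1 = net(p1)
x2 = torch.cross(d, p1[..., 3:], dim=-1) + d * y1

# Forward pass 2: the same network again, now driven by x2 (detached)
p2 = make_p(x2.detach(), d)
y2 = net(p2)

# Gradient of y w.r.t. d "given x2 as the input"
(grad_d,) = torch.autograd.grad(y2.sum(), d)
```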

However, this is expensive and mostly wasted work, since most of the values computed in the two forward passes are identical.
My goal is to take the computation graph of a forward pass based on x1 and substitute the part of the graph unique to x1 with a partial forward pass based on x2.

My problem is that `torch.Tensor.grad_fn` is read-only.
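One possible workaround avoids rewriting `grad_fn` altogether. This is a swapped-in technique (a split backward pass / vector-Jacobian product), not graph surgery. It is only a sketch and assumes y depends on x solely through p, as in the code below. Because p1 and p2 are numerically equal, ∂y/∂p is the same for both. The idea is to backprop y once, only as far as p1, with `torch.autograd.grad`. That gradient is then pushed through a freshly built, cheap p2 graph via `grad_outputs`. `make_p` is a hypothetical helper:

```python
import torch
from torch import nn

torch.manual_seed(0)  # only for reproducibility of the sketch


def make_p(x, d):
    # Hypothetical helper: p = (d, m) with moment m = x × d, as in the question
    return torch.cat((d, torch.cross(x, d, dim=-1)), dim=-1)


d_raw = torch.tensor([[1.0, 2.0, 3.0]], requires_grad=True)
d = d_raw / d_raw.norm(dim=-1, keepdim=True)
x1 = torch.tensor([[4.0, 5.0, 6.0]])

net = nn.Sequential(
    nn.Linear(6, 10),
    nn.Linear(10, 10),
    nn.Linear(10, 10),
    nn.Linear(10, 1),
)

# Single forward pass through the network, driven by x1
p1 = make_p(x1, d)
y = net(p1)
x2 = torch.cross(d, p1[..., 3:], dim=-1) + d * y

# Step 1: backprop y only as far as the network input p1.
# Since p1 == p2 numerically, this equals dy/dp evaluated at p2.
(g_p,) = torch.autograd.grad(y.sum(), p1)

# Step 2: rebuild only the cheap x-dependent part of the graph from x2
# and push g_p through it (vector-Jacobian product) to reach d.
p2 = make_p(x2.detach(), d)
(grad_d,) = torch.autograd.grad(p2, d, grad_outputs=g_p)
```

The network here is purely linear, so ∂y/∂p does not even depend on p. With nonlinear layers, it is the numerical equality p1 == p2 that makes reusing `g_p` valid.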

My exploration thus far:

```python
import torch
from torch import nn

# Input

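# requires_grad=True: without it p1/p2 get no grad_fn and y cannot be differentiated w.r.t. d
d = torch.tensor([[1., 2., 3.]], requires_grad=True)
d = d / d.norm(dim=-1, keepdim=True)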
x1 = torch.tensor([4, 5, 6]).float()[None, ...]

# Forward

m1 = torch.cross(x1, d, dim=-1)
p1 = torch.cat((d, m1), dim=-1)

p1 = p1 + 0 # a no-op "shim" to enable some of the attempted grad_fn replacement tricks below

net = nn.Sequential(
    nn.Linear( 6, 10),
    nn.Linear(10, 10),
    nn.Linear(10, 10),
    nn.Linear(10,  1),
)

y = net(p1)
x2 = d.cross(m1, dim=-1) + d * y

# Loss

# Compute a new p, based on predicted x2
m2 = torch.cross(x2.detach(), d, dim=-1)
p2 = torch.cat((d, m2), dim=-1)

# p1 and p2 are equal, despite x1 and x2 not being equal
assert not torch.allclose(x1, x2),  f"\n{x1}\n{x2}"
assert     torch.allclose(p1, p2),  f"\n{p1}\n{p2}"

# Objective:
# Compute the gradient of y w.r.t. d as if x2, not x1, had been the input

# <CatBackward0 object at 0x7f93737d0280>
# ((<CatBackward0 object at 0x7f93737d0fd0>, 0), (None, 0))
# ((<CatBackward0 object at 0x7f93737d0fd0>, 0), (None, 0))

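# GOAL: replace the CatBackward0 in p1.grad_fn.next_functions (built from x1) with p2.grad_fn (built from x2)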

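# === Attempt #1:
try:
    p1.grad_fn.next_functions[0] = (p2.grad_fn, 0)  # assumed statement, inferred from the error below
    # PROBLEM: this ^ operation is illegal:
except Exception as e:
    print(f"ERROR: {e.__class__.__name__}: {e}")
    # 'tuple' object does not support item assignment
    # The tuple is just a copy, and not a view anyway: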
# The tuple is just a copy, and not a view anyway:

# === Attempt #2:
try:
    p1.grad_fn.next_functions = ((p2.grad_fn, 0), (None, 0))  # assumed statement, inferred from the error below
    # PROBLEM: this ^ operation is also illegal:
except Exception as e:
    print(f"ERROR: {e.__class__.__name__}: {e}")
    # AttributeError: attribute 'next_functions' of 'AddBackward0' objects is not writable

# === Attempt #3:
try:
    p1.grad_fn = p2.grad_fn  # assumed statement, inferred from the error below
    # PROBLEM: this ^ operation is also illegal:
except Exception as e:
    print(f"ERROR: {e.__class__.__name__}: {e}")
    # AttributeError: attribute 'grad_fn' of 'torch._C._TensorBase' objects is not writable

# Given that my objective above was achieved:

# Compute y grad w.r.t. d
(grad_d,) = torch.autograd.grad(y.sum(), d)
```