I need help modifying (rewriting) a PyTorch autograd computation graph.

I have a model f with the following property:

f(d, x1) = x2 => f(d, x2) = x2

This holds by construction, because of how I compute the normalized p in the code below.

I need to compute the gradient of f w.r.t. d with x2 as the input, but at training time I only have x1 available for supervision.

∇d f(d, x2)

The easiest way to achieve it is to just compute the forward pass twice:

∇d f(d, f(d, x1))

This is expensive, however, and mostly wasted work, since most of the values computed in the two forward passes are identical.
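For concreteness, the double forward pass could be sketched roughly as follows. The wrapper `f` and the small `net` are hypothetical stand-ins, not the exact model from the code further down, and the gradient is taken of the scalar network output `y` (as in the last lines of that code), with x2 detached so that it is treated as a plain input.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical stand-in for the real network.
net = nn.Sequential(nn.Linear(6, 10), nn.Tanh(), nn.Linear(10, 1))

def f(d, x):
    """Hypothetical wrapper around one full forward pass; returns (x_out, y)."""
    m = torch.cross(x, d, dim=-1)   # only the part of x orthogonal to d survives
    p = torch.cat((d, m), dim=-1)   # normalized network input
    y = net(p)
    return torch.cross(d, m, dim=-1) + d * y, y

d = torch.tensor([[1.0, 2.0, 3.0]])
d = (d / d.norm(dim=-1, keepdim=True)).requires_grad_()
x1 = torch.tensor([[4.0, 5.0, 6.0]])

x2, _ = f(d, x1)                   # first pass, from the supervised x1
x2_again, y2 = f(d, x2.detach())   # second pass, x2 treated as a constant input

# Gradient of the second pass w.r.t. d; create_graph=True keeps it differentiable
# so that a loss on it can still be backpropagated to the network weights.
grad_d = torch.autograd.grad(
    y2, d, grad_outputs=torch.ones_like(y2), create_graph=True
)[0]
```

Both passes run the whole network, which is exactly the duplicated work I would like to avoid.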

My goal is to modify the computation graph of a forward pass based on x1, and substitute the part of the graph unique to x1 with a partial forward pass based on x2.

My problem is that `torch.Tensor.grad_fn` is read-only.

My exploration thus far:

```
import torch
from torch import nn
# Input
d = torch.tensor([1, 2, 3]).float()[None, ...]
d = d / d.norm(dim=-1, keepdim=True)
x1 = torch.tensor([4, 5, 6]).float()[None, ...]
# Forward
d.requires_grad = True
m1 = torch.cross(x1, d, dim=-1)
p1 = torch.cat((d, m1), dim=-1)
p1 = p1 + 0 # a no-op "shim" to enable some of the attempted grad_fn replacement tricks below
net = nn.Sequential(
    nn.Linear( 6, 10),
    nn.Linear(10, 10),
    nn.Linear(10, 10),
    nn.Linear(10,  1),
)
y = net(p1)
x2 = d.cross(m1, dim=-1) + d * y
# Loss
# Compute a new p, based on predicted x2
m2 = torch.cross(x2.detach(), d, dim=-1)
p2 = torch.cat((d, m2), dim=-1)
# p1 and p2 are equal, despite x1 and x2 not being equal
assert not torch.allclose(x1, x2), f"\n{x1}\n{x2}"
assert torch.allclose(p1, p2), f"\n{p1}\n{p2}"
# Objective:
# Instead of computing the gradient of y w.r.t. d given x1, compute it as if x2 were the input
print(p1.grad_fn) # p1 + 0
# <AddBackward0 object at 0x7f93737d0280>
print(p2.grad_fn) # p2
# <CatBackward0 object at 0x7f93737d0280>
print(p1.grad_fn.next_functions)
# ((<CatBackward0 object at 0x7f93737d0fd0>, 0), (None, 0))
print((p2 + 0).grad_fn.next_functions)
# ((<CatBackward0 object at 0x7f93737d0fd0>, 0), (None, 0))
# GOAL:
# I want to replace the CatBackward0 in p1.grad_fn.next_functions with p2.grad_fn
# === Attempt #1:
try:
    p1.grad_fn.next_functions[0][0] = p2.grad_fn
    # PROBLEM: this ^ operation is illegal:
except Exception as e:
    print(f"ERROR: {e.__class__.__name__}: {e}")
    # TypeError: 'tuple' object does not support item assignment
# The tuple is just a copy, and not a view anyway:
assert p1.grad_fn.next_functions is not p1.grad_fn.next_functions
# === Attempt #2:
try:
    p1.grad_fn.next_functions = (p2 + 0).grad_fn.next_functions
    # PROBLEM: this ^ operation is also illegal:
except Exception as e:
    print(f"ERROR: {e.__class__.__name__}: {e}")
    # AttributeError: attribute 'next_functions' of 'AddBackward0' objects is not writable
# === Attempt #3:
try:
    p1.grad_fn = p2.grad_fn
    # PROBLEM: this ^ operation is also illegal:
except Exception as e:
    print(f"ERROR: {e.__class__.__name__}: {e}")
    # AttributeError: attribute 'grad_fn' of 'torch._C._TensorBase' objects is not writable
# Assuming the objective above has been achieved:
# Compute y grad w.r.t. d
y_grad = torch.autograd.grad(y, [d], grad_outputs=torch.ones_like(y), create_graph=True)[0].norm()
loss = y_grad.abs().sum()
```

How can I modify `grad_fn`?
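For reference, the result such a graph edit would produce can also be written as an explicit chain rule, without touching `grad_fn`. The following is only a sketch under the assumption that p1 and p2 agree numerically (as the asserts above check): take dy/dp from the single full forward pass, then push it through the cheap x2-based graph of p2. The small `net` with a `Tanh` is a hypothetical stand-in, and this only reproduces the first-order gradient w.r.t. d; I am not sure it covers every use case.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical stand-in for the real network.
net = nn.Sequential(nn.Linear(6, 10), nn.Tanh(), nn.Linear(10, 1))

d = torch.tensor([[1.0, 2.0, 3.0]])
d = (d / d.norm(dim=-1, keepdim=True)).requires_grad_()
x1 = torch.tensor([[4.0, 5.0, 6.0]])

# Single full forward pass, based on x1.
m1 = torch.cross(x1, d, dim=-1)
p1 = torch.cat((d, m1), dim=-1)
y = net(p1)
x2 = torch.cross(d, m1, dim=-1) + d * y

# Partial forward pass, based on x2: only the cheap part in front of the network.
m2 = torch.cross(x2.detach(), d, dim=-1)
p2 = torch.cat((d, m2), dim=-1)

# dy/dp, evaluated at p1 (numerically equal to p2).
dy_dp = torch.autograd.grad(
    y, p1, grad_outputs=torch.ones_like(y), create_graph=True
)[0]

# Chain rule: propagate dy/dp through the graph of p2 instead of the graph of p1.
grad_d = torch.autograd.grad(p2, d, grad_outputs=dy_dp, create_graph=True)[0]

loss = grad_d.norm()
```

Here `grad_d` should match the gradient obtained from a second full pass `net(p2)`, while the network itself is only evaluated once.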