You can just extract the common operation out of operation1 and operation2 and pass its result in as an input.
That would be the simplest (and most straightforward) thing you can do.
Let me clarify a couple of things.
Autograd accumulates gradients over multiple calls to backward().
Keep in mind that if you call a submodule twice, you are effectively creating a Siamese-like network: the weights are shared and both calls contribute to its gradients.
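A quick sketch of both points (the nn.Linear layer and the toy inputs below are just for illustration, not taken from your code):

import torch
import torch.nn as nn

shared = nn.Linear(4, 1)                  # one submodule used twice -> Siamese-like weight sharing
x1, x2 = torch.rand(4), torch.rand(4)

loss1 = shared(x1).sum()                  # first branch
loss2 = shared(x2).sum()                  # second branch, same weights

loss1.backward()
print(shared.weight.grad)                 # gradient from the first call only
loss2.backward()
print(shared.weight.grad)                 # gradients from both calls, summed (accumulation)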
You can do a trick like this:
import torch
import torch.nn as nn


class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.flag = False

    def common_operation(self, x):
        print(f'Common operation done? {self.flag}')
        if self.flag:
            # Reuse the cached tensor and reset the flag for the next fresh input.
            self.flag = False
            return self.tmp
        else:
            # Compute and cache the result for the next call.
            self.tmp = x ** 2
            self.flag = True
            return self.tmp

    def operation1(self, x):
        tmp = self.common_operation(x)
        print(tmp)
        print(tmp._version)   # version counter stays 0: tmp is never modified in-place
        print(id(tmp))        # same object is reused by the second call
        val1 = x - tmp
        return val1

    def operation2(self, x):
        tmp = self.common_operation(x)
        print(tmp)
        print(tmp._version)
        print(id(tmp))
        val2 = tmp * x.mean()
        return val2
model = MyModel()
x1 = torch.rand(10).requires_grad_()
x2 = torch.rand(10).requires_grad_()

o1 = model.operation1(x1)   # computes and caches x1 ** 2
o2 = model.operation2(x2)   # reuses the cached tensor from the first call

s2 = o2.mean()
s1 = o1.mean()

# s1 and s2 share the cached sub-graph, so the first backward must retain it.
s1.backward(retain_graph=True)
print(f'X1 gradient {x1.grad} before X2')
s2.backward()
print(f'X1 gradients {x1.grad}')
print(f'X2 gradients {x2.grad}')
Common operation done? False
tensor([0.2605, 0.1031, 0.8378, 0.8229, 0.9469, 0.3439, 0.9053, 0.2633, 0.5918,
0.9358], grad_fn=<PowBackward0>)
0
139697179178544
Common operation done? True
tensor([0.2605, 0.1031, 0.8378, 0.8229, 0.9469, 0.3439, 0.9053, 0.2633, 0.5918,
0.9358], grad_fn=<PowBackward0>)
0
139697179178544
X1 gradient tensor([-0.0021, 0.0358, -0.0831, -0.0814, -0.0946, -0.0173, -0.0903, -0.0026,
-0.0539, -0.0935]) before X2
X1 gradients tensor([0.0572, 0.0731, 0.0232, 0.0239, 0.0184, 0.0508, 0.0202, 0.0570, 0.0355,
0.0189])
X2 gradients tensor([0.0601, 0.0601, 0.0601, 0.0601, 0.0601, 0.0601, 0.0601, 0.0601, 0.0601,
0.0601])
As you can see, the version of the tensor is the same and it is never modified; even the object ID stays the same.
However, the contributions to the gradients depend on both operations.
It really depends on what kind of gradients you are looking for.
Anyway, your case sounds like you should remove common_op from op1 and op2, and pass the result of common_op to op1 and op2 as an input.
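A minimal sketch of that refactor (the class name MyModelRefactored and the method signatures here are just illustrative, not from your code):

import torch
import torch.nn as nn


class MyModelRefactored(nn.Module):
    def common_operation(self, x):
        return x ** 2

    def operation1(self, x, common):      # receives the precomputed result
        return x - common

    def operation2(self, x, common):
        return common * x.mean()


model = MyModelRefactored()
x1 = torch.rand(10).requires_grad_()
x2 = torch.rand(10).requires_grad_()

common = model.common_operation(x1)       # computed exactly once
o1 = model.operation1(x1, common)
o2 = model.operation2(x2, common)

# One backward over the combined loss; no retain_graph needed.
(o1.mean() + o2.mean()).backward()
print(x1.grad)
print(x2.grad)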