Modify Computational Graph

Hi all,

I’d like to do something conceptually simple, but I’m not sure how to implement it in practice:

f(x).backward()
grad_x = x.grad
# stuff

y = x + u(g)
f(y).backward()
grad_g = g.grad
# stuff

I’d like to be able to do this on the fly, i.e. without maintaining two separate graphs for f(x) and f(y), as these will be user-specific. Ideally I’d like to perform an in-place update on x, y = x + u(g), and then pass the result back through f. Is this possible in PyTorch?
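For concreteness, with plain tensors the pattern I have in mind looks roughly like this (f and u below are just stand-ins for the real functions):

import torch

x = torch.rand(3, requires_grad=True)
g = torch.rand(3, requires_grad=True)

def f(t):          # stand-in objective
    return (t ** 2).sum()

def u(t):          # stand-in perturbation applied to g
    return 0.1 * t

f(x).backward()
grad_x = x.grad

y = x + u(g)       # y stays connected to g in the graph
f(y).backward()
grad_g = g.grad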

Here’s an extremely simple example to clarify the issue:


import torch

class TestFunc(torch.nn.Module):
    
    def __init__(self):
        super(TestFunc, self).__init__()
        
        self.alpha = torch.nn.Parameter(torch.tensor([100.0]), requires_grad=True)

    def forward(self, input: torch.Tensor):
        x_1 = torch.roll(input, +1)
        return torch.sum(self.alpha * (x_1 - input ** 2) ** 2 + (1 - input) ** 2, 0)


input = torch.rand(2, requires_grad=True) * 5

model = TestFunc()

model.register_parameter("gamma", torch.nn.Parameter(torch.tensor([0.1], requires_grad=True)))

model.alpha = torch.nn.Parameter(torch.exp(model.gamma), requires_grad=True)

output2 = model(input)

output2.backward()

print (model.alpha)
print (model.alpha.grad)
print (model.gamma)
print (model.gamma.grad)

outputs:

Parameter containing:
tensor([1.1052], requires_grad=True)
tensor([617.8456])
Parameter containing:
tensor([0.1000], requires_grad=True)
None

I’d like to know if there’s a way to get gamma.grad to not be None!

Hi, I tried this and it works, but I’m not sure if it is what you want:

import torch

class TestFunc(torch.nn.Module):

    def __init__(self):
        super(TestFunc, self).__init__()

        # self.alpha = torch.nn.Parameter(torch.tensor([100.0]), requires_grad=True)
        self.gamma = torch.nn.Parameter(torch.tensor([0.1], requires_grad=True))

    def forward(self, input: torch.Tensor):
        x_1 = torch.roll(input, +1)
        self.alpha = torch.exp(self.gamma)  # recompute alpha from gamma inside forward so it stays in the graph
        self.alpha.retain_grad()
        return torch.sum(self.alpha * (x_1 - input ** 2) ** 2 + (1 - input) ** 2, 0)


input = torch.rand(2, requires_grad=True) * 5

model = TestFunc()

# model.register_parameter("gamma", torch.nn.Parameter(torch.tensor([0.1], requires_grad=True)))

# model.alpha = torch.nn.Parameter(torch.exp(model.gamma), requires_grad=True)

output2 = model(input)

output2.backward()

print (model.alpha)
print (model.alpha.grad)
print (model.gamma)
print (model.gamma.grad)

outputs:

tensor([1.1052], grad_fn=<ExpBackward0>)
tensor([361.4180])
Parameter containing:
tensor([0.1000], requires_grad=True)
tensor([399.4287])

This is not what I want.

The model architecture will be unknown to me in my scenario.

I need to be able to modify arbitrary parameters of existing modules in the way that I showed.

Parameter objects are graph leaves, so backward in this case won’t propagate further to model.gamma: its use is hidden behind another leaf node (model.alpha), and backpropagation stops at leaves.
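You can see this in a few lines, mirroring your original alpha assignment:

import torch

gamma = torch.nn.Parameter(torch.tensor([0.1]))
alpha = torch.nn.Parameter(torch.exp(gamma), requires_grad=True)  # a brand-new leaf

loss = (alpha ** 2).sum()
loss.backward()

print(alpha.is_leaf, alpha.grad)   # True, alpha.grad is populated
print(gamma.is_leaf, gamma.grad)   # True, gamma.grad is still None -> backward stopped at the leaf alpha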

import torch

from collections import OrderedDict

class TestFunc(torch.nn.Module):
    def __init__(self):
        super(TestFunc, self).__init__()
        self.alpha = torch.nn.Parameter(torch.tensor([100.0]), requires_grad=True)
    def forward(self, input: torch.Tensor):
        x_1 = torch.roll(input, +1)
        return torch.sum(self.alpha * (x_1 - input ** 2) ** 2 + (1 - input) ** 2, 0)


input = torch.rand(2, requires_grad=True) * 5

model = TestFunc()

model.register_parameter("gamma", torch.nn.Parameter(torch.tensor([0.1], requires_grad=True)))

params = OrderedDict(model.named_parameters())
params['alpha'] = torch.exp(params['gamma'])
output2 = model(input)

output2.backward()
params['alpha'].backward()


print(model.alpha)
print(model.alpha.grad)
print(model.gamma)
print(model.gamma.grad)

outputs:

tensor([100.], requires_grad=True)
tensor([1.9056])
Parameter containing:
tensor([0.1000], requires_grad=True)
tensor([1.1052])

So what I changed was calling params['alpha'].backward() directly; that way it seems to be able to reach model.gamma.

I also changed how you compute alpha, since your original code, model.alpha = torch.nn.Parameter(torch.exp(model.gamma), requires_grad=True), only created a new Parameter object and therefore didn’t connect alpha to the computation graph.
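The difference is visible directly on grad_fn (quick check):

import torch

gamma = torch.nn.Parameter(torch.tensor([0.1]))
wrapped = torch.nn.Parameter(torch.exp(gamma), requires_grad=True)  # your original assignment: a fresh leaf
plain = torch.exp(gamma)                                            # the version in my snippet: a non-leaf

print(wrapped.grad_fn)  # None -> the exp() history was dropped
print(plain.grad_fn)    # <ExpBackward0 ...> -> still connected to gamma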

Hope it helps!

Thanks for that, but unfortunately I still need something more general than this.

model.alpha could be a non-scalar tensor too, in which case you can’t call .backward() on it directly.

I only gave this snippet as an example of the issue, not as a final use case!

output2 is just a tensor as well, so it should be possible

output2 is a scalar variable (e.g. a loss function), therefore you can call .backward() on it.

alpha/beta/gamma/whatever could be a vector-valued or matrix-valued variable, which means you can’t call .backward() on it unless you loop through all elements in the tensor, which is not feasible.

Yes, sorry, you are correct. Maybe you can sum or take the mean?
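Or, to keep the chain rule exact when alpha is tensor-valued, you could pass the gradient from the first backward as the gradient argument of the second one, so that dL/dgamma = dL/dalpha * dalpha/dgamma is accumulated without looping. A rough sketch, reusing the TestFunc from my snippet above (untested against your real setup):

import torch

model = TestFunc()
model.register_parameter("gamma", torch.nn.Parameter(torch.tensor([0.1])))

# make the forward pass actually use alpha = exp(gamma), without rebuilding the module
with torch.no_grad():
    model.alpha.copy_(torch.exp(model.gamma))

input = torch.rand(2, requires_grad=True) * 5

output2 = model(input)
output2.backward()                      # fills model.alpha.grad = dL/dalpha (any shape)

alpha_expr = torch.exp(model.gamma)     # recompute alpha from gamma, this time with autograd history
alpha_expr.backward(gradient=model.alpha.grad)  # chain rule: accumulates dL/dgamma into model.gamma.grad

print(model.gamma.grad)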