Methods for checking autograd behavior

This is a follow-up to my earlier post. Here is an example of the sort of thing I want to consider. Consider the scalar case, where we have:

f(x) = w2 * (z1 * (w1 x + b1)) + b2

and

z1 = exp(w1 x + b1)

A naive implementation would be:

import torch
import torch.nn as nn

class ScalarNN1(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(1, 1)
        self.layer2 = nn.Linear(1, 1)

    def forward(self, x):
        x1 = self.layer1(x)      # w1 * x + b1
        z = torch.exp(x1)        # z1 = exp(w1 * x + b1)
        y = self.layer2(z * x1)  # w2 * (z1 * (w1 * x + b1)) + b2
        return y

When I compute gradients with respect to the parameters w1 and b1 (in layer1), the chain rule will backpropagate through z. Alternatively, suppose I construct:

class ScalarNN2(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(1, 1)
        self.layer2 = nn.Linear(1, 1)

    def forward(self, x):
        x1 = self.layer1(x)
        with torch.no_grad():
            z = torch.exp(x1) # this could also be done with detach
        y = self.layer2(z * x1)
        return y

Because z is computed inside torch.no_grad(), no gradient will flow back through it: z is treated as a constant during the backward pass.
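To make the difference concrete, here is a minimal sketch (assuming the two modules above; the closed-form expressions in the comments are my hand derivation for this scalar case) that compares the gradient of the output with respect to w1 in the two setups:

import torch

torch.manual_seed(0)

m1 = ScalarNN1()
m2 = ScalarNN2()
m2.load_state_dict(m1.state_dict())  # same w1, b1, w2, b2 in both modules

x = torch.tensor([[0.5]])

# Gradient of the scalar output with respect to layer1's weight (w1).
g1 = torch.autograd.grad(m1(x).sum(), m1.layer1.weight)[0]
g2 = torch.autograd.grad(m2(x).sum(), m2.layer1.weight)[0]

print(g1)  # w2 * z1 * (1 + w1*x + b1) * x  -- exp term is differentiated
print(g2)  # w2 * z1 * x                    -- z1 treated as a constant

With a generic random initialization the two printed gradients should differ, confirming that the exp path is cut in ScalarNN2.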

I consider these two cases because I am working with training routines where different parts of the network are trained differently, and I want to be able to take gradients only through certain intermediates. While I can check this example by hand, I was hoping to find a way of checking this more generally. I had hoped to use something graphical like torchview, but that doesn't seem to do it.

Any suggestions would be appreciated.

Are you looking for something like torchviz (https://github.com/szagoruyko/pytorchviz), a small package to create visualizations of PyTorch execution graphs?
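For example, a minimal torchviz sketch (using your ScalarNN1 module; make_dot is torchviz's entry point) would be something like:

import torch
from torchviz import make_dot

model = ScalarNN1()
y = model(torch.tensor([[0.5]]))

# Build a graphviz Digraph of the autograd graph; passing the named
# parameters labels the leaf nodes (w1, b1, w2, b2) in the plot.
dot = make_dot(y, params=dict(model.named_parameters()))
dot.render("scalar_nn1_graph", format="png")

Doing the same with ScalarNN2 should show the exp-related backward node missing, since z is created under no_grad and is not part of the graph.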

Another thing you can do is set the environment variable TORCH_LOGS="+autograd", e.g.:

TORCH_LOGS="+autograd" python test.py

This makes autograd log the nodes that the backward pass computes.
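As a complementary check that does not rely on the logs (a sketch using only the public grad_fn / next_functions attributes; the helper name below is just for illustration), you can also walk the backward graph from the output and list its node types:

import torch

def list_backward_nodes(output):
    # Walk the autograd graph starting at the output's grad_fn and
    # collect the class names of every backward node reached.
    seen, stack, names = set(), [output.grad_fn], []
    while stack:
        node = stack.pop()
        if node is None or node in seen:
            continue
        seen.add(node)
        names.append(type(node).__name__)
        stack.extend(fn for fn, _ in node.next_functions)
    return names

print(list_backward_nodes(ScalarNN1()(torch.tensor([[0.5]]))))  # contains an Exp backward node
print(list_backward_nodes(ScalarNN2()(torch.tensor([[0.5]]))))  # no Exp node: z was computed under no_grad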

Yes, this seems to extract the information I want.