What is truly happening when we define dynamic graph models?

Hi everyone! I’m currently exploring the possibility of encoding a dynamic computational graph with PyTorch, and I’m a little confused about what is happening to my “dynamic model”.

As far as I understand, it’s possible to create models where, for instance, the number of layers and/or neurons per layer can change ([reference]) using Python control-flow operators like loops or conditional statements. However, I cannot figure out what happens to the learnable parameters in such a dynamic graph.

Just to be clearer, consider this snippet.
Basically, at each forward pass (that is to say, for every batch) we randomly roll a “coin” (really a four-sided die) that leads to different architectures, namely with 0, 1, 2 or 3 hidden layers.

import random

import torch
import torch.nn.functional as F


class DynamicNet(torch.nn.Module):
    def __init__(self, D_in, H1, H2, D_out):
        super(DynamicNet, self).__init__()
        self.input_linear = torch.nn.Linear(D_in, H1)
        self.middle_linear1 = torch.nn.Linear(H1, H2)
        self.middle_linear2 = torch.nn.Linear(H2, H1)
        self.middle_linear3 = torch.nn.Linear(H1, H1)
        self.output_linear = torch.nn.Linear(H1, D_out)

    def forward(self, x):
        x = F.relu(self.input_linear(x))
        coin = random.randint(0, 3)  # 0, 1, 2 or 3 (both ends inclusive)
        if coin == 1:
            # 1 hidden layer: H1 -> H1, so the shape still fits output_linear
            x = F.relu(self.middle_linear3(x))
        elif coin == 2:
            x = F.relu(self.middle_linear1(x))  # H1 -> H2
            x = F.relu(self.middle_linear2(x))  # H2 -> H1
        elif coin == 3:
            x = F.relu(self.middle_linear1(x))
            x = F.relu(self.middle_linear2(x))
            x = F.relu(self.middle_linear3(x))
        # coin == 0: no hidden layer at all
        # The output layer is applied on every path (raw logits, no ReLU)
        x = self.output_linear(x)
        return F.log_softmax(x, dim=1)

My doubts are the following:

  1. Am I really exploiting PyTorch’s dynamic graph capability? From my perspective, I’m basically creating a tree-like structure where each branch is taken with some probability.

  2. How are the weight matrices updated?

  3. What will the final model that I eventually save for future use look like?

  4. What are the answers to the previous three questions in this second case?

import random

import torch
import torch.nn.functional as F


class DynamicNet(torch.nn.Module):
    def __init__(self, D_in, H, D_out):
        super(DynamicNet, self).__init__()
        self.input_linear = torch.nn.Linear(D_in, H)
        self.middle_linear = torch.nn.Linear(H, H)
        self.output_linear = torch.nn.Linear(H, D_out)

    def forward(self, x):
        x = F.relu(self.input_linear(x))
        coin = random.randint(0, 3)
        # The same middle_linear module is reused 0 to 3 times
        for _ in range(coin):
            x = F.relu(self.middle_linear(x))
        x = self.output_linear(x)  # raw logits, no ReLU before log_softmax
        return F.log_softmax(x, dim=1)
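In the second snippet the loop reuses one `middle_linear` module up to three times, so its weights are shared across the iterations, and autograd sums the gradient contributions from every use. The following minimal sketch checks this on a toy layer (the sizes and the names `shared`, `copy_a`, `copy_b` are illustrative assumptions): one layer applied twice is compared with two independent copies that hold the same weights.

```python
import copy

import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(8, 4)

# One layer applied twice, like middle_linear inside the loop when coin == 2
shared = nn.Linear(4, 4)
# Two independent copies with the same starting weights (copied before any backward)
copy_a = copy.deepcopy(shared)
copy_b = copy.deepcopy(shared)

F.relu(shared(F.relu(shared(x)))).sum().backward()
F.relu(copy_b(F.relu(copy_a(x)))).sum().backward()

# The shared layer's gradient is the sum of the contributions of both uses
print(torch.allclose(shared.weight.grad, copy_a.weight.grad + copy_b.weight.grad))  # True
```

So in the second case there is a single set of `middle_linear` parameters no matter how many times the loop runs; a batch that loops more often simply contributes more terms to its gradient.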

Thanks a lot in advance for your answers!

Regarding point 2: When you feed a batch to your network, forward (including your dice roll) is called. When you then call .backward() on some scalar value computed from your network output, gradients are computed only with respect to the weights that were actually used to produce that output (which depends on the outcome of the dice roll). Gradients with respect to unused weights are not calculated. For example, if you roll coin == 2, then for any weight w of self.middle_linear3 you should have w.grad is None.
EDIT: Note that once self.middle_linear3 has been used in at least one earlier forward/backward call, w.grad will no longer be None. If it is not used in a later forward/backward call, calling .backward simply won’t change it: w.grad keeps whatever the last zeroing left there, which is zeros if gradients are zeroed in place between optimization steps. Be aware that since PyTorch 2.0, zero_grad() defaults to set_to_none=True, which resets gradients to None instead of filling them with zeros.
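Gradients only affect the weights once an optimizer step uses them. A minimal sketch, assuming plain `torch.optim.SGD` without momentum or weight decay, and a fixed branch choice instead of a random one for clarity: the layer that took no part in the forward pass keeps `grad is None`, so SGD skips it and its weights stay exactly as they were. With momentum or weight decay enabled, a parameter whose gradient is zeros (rather than None) could still move during `step()`.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Two branches; only fc_1 is used in this forward pass
fc_1 = nn.Linear(4, 1)
fc_2 = nn.Linear(4, 1)
params = list(fc_1.parameters()) + list(fc_2.parameters())
optimizer = torch.optim.SGD(params, lr=0.1)  # plain SGD: no momentum, no weight decay

before_1 = fc_1.weight.detach().clone()
before_2 = fc_2.weight.detach().clone()

data = torch.rand(32, 4)
optimizer.zero_grad()
loss = fc_1(data).sum()  # fc_2 is not part of this graph
loss.backward()
optimizer.step()

print(fc_2.weight.grad)                    # None
print(torch.equal(fc_2.weight, before_2))  # True: unused branch unchanged
print(torch.equal(fc_1.weight, before_1))  # False: used branch updated
```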

Here is a code example that hopefully explains it well:

import torch
import torch.nn as nn
import random


class TestNet(nn.Module):
    def __init__(self):
        super(TestNet, self).__init__()
        self.fc_1 = nn.Linear(4, 1)
        self.fc_2 = nn.Linear(4, 1)

    def forward(self, x):
        if random.random() < 0.5:
            x = self.fc_1(x)
        else:
            x = self.fc_2(x)
        return x


net = TestNet()
data = torch.rand(32, 4)
out = net(data)
loss = out.sum()
loss.backward()

# We called forward/backward once, the gradients of the weights (parameters) of
# either fc_1 or fc_2 should be None
print("Weights and gradients after one f/b call")
for param in net.parameters():
    print(param)
    print("Gradient:", param.grad, "\n")

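# Let's do 10 more forward/backward steps (setting grad to zero between steps)
for _ in range(10):
    out = net(data)
    loss = out.sum()
    # set_to_none=False zeroes existing gradients in place; since PyTorch 2.0
    # the default (True) would reset them to None instead
    net.zero_grad(set_to_none=False)
    loss.backward()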

# Now (unless we were very unlucky, with probability 0.5**10) no gradient is
# None, and only the gradients of the weights that were used in the last
# iteration are non-zero
print("Weights and gradients after 10 f/b calls")
for param in net.parameters():
    print(param)
    print("Gradient:", param.grad, "\n")

Partial answer to point 3: The recommended way is to save only the model weights (the model’s state_dict). I think this is how it works in your case: since only the weights are saved, it does not matter what your forward function looks like. Your DynamicNet object knows which modules (5x Linear) it contains and saves their weights. You could create another Net with a completely different forward function and still load your saved weights, as long as it has the same modules (with the same names and shapes).
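A minimal sketch of this, assuming the two-branch `TestNet` layout from the example above: `TrainNet` picks a random branch, while the hypothetical `InferenceNet` uses a different, deterministic forward (averaging the two branches is an arbitrary illustrative choice). An in-memory buffer stands in for a file on disk.

```python
import io
import random

import torch
import torch.nn as nn


class TrainNet(nn.Module):
    # Random branch at training time
    def __init__(self):
        super().__init__()
        self.fc_1 = nn.Linear(4, 1)
        self.fc_2 = nn.Linear(4, 1)

    def forward(self, x):
        return self.fc_1(x) if random.random() < 0.5 else self.fc_2(x)


class InferenceNet(nn.Module):
    # Same module names and shapes, different forward
    def __init__(self):
        super().__init__()
        self.fc_1 = nn.Linear(4, 1)
        self.fc_2 = nn.Linear(4, 1)

    def forward(self, x):
        # Hypothetical test-time rule: average the two branches
        return 0.5 * (self.fc_1(x) + self.fc_2(x))


trained = TrainNet()
buffer = io.BytesIO()  # stands in for a file path
torch.save(trained.state_dict(), buffer)
buffer.seek(0)

restored = InferenceNet()
restored.load_state_dict(torch.load(buffer))
print(list(trained.state_dict().keys()))
# ['fc_1.weight', 'fc_1.bias', 'fc_2.weight', 'fc_2.bias']
print(torch.equal(restored.fc_2.weight, trained.fc_2.weight))  # True
```

The state_dict only holds named tensors; nothing about the control flow in forward is stored, which is why the loading side is free to use a different forward.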


Thanks for your answer @ptab

So, regarding point 3: imagine that I’m no longer rolling a die, but branching based on some property of the input, something like:

# x.norm(2) is computed over the whole batch tensor, not per sample
while x.norm(2) < 10:
    x = F.relu(self.middle_linear1(x))

I will train my module weights following this forward rule. But what if I present the model with a sample that does not satisfy the while condition? I mean, is the forward function only used during training, so that it could be totally different in a future test phase where I’m using the saved model?

The model always has the same parameters. When you run the forward function during training, you build a dynamic computation graph that tells the backpropagation algorithm how to compute the gradients used to update the parameters. But the parameter instances remain the same, regardless of which path the forward pass takes.

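This means that the same forward function also runs in the testing phase: for each input it determines how the parameters are used (a sample that does not satisfy the while condition simply skips the loop), but the parameters themselves are all there and don’t change.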

If, for some strange reason, you change the forward function after training the model, the parameters will still be there and the new forward function will define how they are used, but that’s about it.
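A minimal sketch of the norm-based loop from the follow-up question, under two assumptions that are not in the thread: a single H x H `middle_linear` (so the layer can be repeated, i.e. H1 == H2), and a hypothetical `max_steps` cap so the loop always terminates (as written, if the activations never reach a norm of 10 the loop would never exit). Setting `threshold=0.0` makes the condition false for every input, since a norm is never negative, which mimics a sample that does not satisfy the while condition; `threshold=inf` forces the cap.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class NormLoopNet(nn.Module):
    # Hypothetical variant of the follow-up: one H x H middle layer (H1 == H2)
    def __init__(self, D_in, H, D_out, threshold=10.0, max_steps=5):
        super().__init__()
        self.input_linear = nn.Linear(D_in, H)
        self.middle_linear = nn.Linear(H, H)
        self.output_linear = nn.Linear(H, D_out)
        self.threshold = threshold
        self.max_steps = max_steps  # hypothetical cap so the loop always ends
        self.last_steps = 0         # loop iterations run by the last call

    def forward(self, x):
        x = F.relu(self.input_linear(x))
        steps = 0
        # Same condition as in the question: norm of the whole batch tensor
        while x.norm(2) < self.threshold and steps < self.max_steps:
            x = F.relu(self.middle_linear(x))
            steps += 1
        self.last_steps = steps
        return F.log_softmax(self.output_linear(x), dim=1)


net = NormLoopNet(D_in=4, H=8, D_out=3, threshold=0.0)
out = net(torch.rand(2, 4))
print(out.shape, net.last_steps)  # torch.Size([2, 3]) 0

# Test time: same forward, same parameters; only the path through them varies
net.eval()
net.threshold = float("inf")
with torch.no_grad():
    out = net(torch.rand(2, 4))
print(out.shape, net.last_steps)  # torch.Size([2, 3]) 5
```

In both calls the module owns the same three Linear layers; the while condition only decides how many times `middle_linear` is applied.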
