Share graph between training and inference

I have a model whose forward pass shares computations between training and inference mode. I don’t want to use a boolean training flag or something similar to declare which mode I’m currently in. Instead, I want a forward pass with two outputs, the first being the training output and the second the inference output. Depending on how I call the forward pass, i.e. which of the two outputs I pick, I want to compute only the part of the graph that is actually needed for that output. Sometimes I want to compute both the training and the inference output in one forward pass. I want a clean model that can cope with all three situations efficiently. Below is a dummy model.

import torch


class Model(torch.nn.Module):

    def __init__(self):
        super(Model, self).__init__()
        self.conv2d_1 = torch.nn.Conv2d(3, 3, 3)
        self.conv2d_2 = torch.nn.Conv2d(3, 3, 3)
        self.conv2d_3 = torch.nn.Conv2d(3, 10, 3)
        self.conv2d_4 = torch.nn.Conv2d(10, 10, 3)
        self.conv2d_5 = torch.nn.Conv2d(10, 10, 3)

    def forward(self, input):
        # Separate input convolutions for each branch
        train_out = self.conv2d_1(input)
        inference_out = self.conv2d_2(input)

        # conv2d_3 is shared between the training and inference branches
        train_out = self.conv2d_3(train_out)
        inference_out = self.conv2d_3(inference_out)

        # Separate output convolutions for each branch
        train_out = self.conv2d_4(train_out)
        inference_out = self.conv2d_5(inference_out)

        return train_out, inference_out

if __name__ == '__main__':

    model = Model()
    training_input = torch.rand(1, 3, 100, 100)
    inference_input = torch.rand(1, 3, 200, 200)
    # Only the training output is needed
    training_out, _ = model(training_input)
    # Only the inference output is needed
    _, inference_out = model(inference_input)

    # Both outputs are needed from a single forward pass
    training_inference_input = torch.rand(1, 3, 100, 100)
    training_out, inference_out = model(training_inference_input)

In this example the two branches share the convolution self.conv2d_3; in reality the model is much more complex. If I call the model like this, the whole graph still gets computed even if I only use one of the outputs. What is the best way to compute only the part of the graph that is needed for the requested output?

The best (and probably only) way would be to use conditions, since the computation graph is created dynamically during the forward pass.
What’s the reason you don’t want to use conditions?
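Something like this could work for your dummy model. It’s just a sketch: the flag names compute_train and compute_inference are placeholders, and you can pick whatever interface fits your real model.

import torch


class Model(torch.nn.Module):

    def __init__(self):
        super(Model, self).__init__()
        self.conv2d_1 = torch.nn.Conv2d(3, 3, 3)
        self.conv2d_2 = torch.nn.Conv2d(3, 3, 3)
        self.conv2d_3 = torch.nn.Conv2d(3, 10, 3)
        self.conv2d_4 = torch.nn.Conv2d(10, 10, 3)
        self.conv2d_5 = torch.nn.Conv2d(10, 10, 3)

    def forward(self, input, compute_train=True, compute_inference=True):
        train_out = None
        inference_out = None
        if compute_train:
            # Training branch, including the shared conv2d_3
            train_out = self.conv2d_4(self.conv2d_3(self.conv2d_1(input)))
        if compute_inference:
            # Inference branch, including the shared conv2d_3
            inference_out = self.conv2d_5(self.conv2d_3(self.conv2d_2(input)))
        return train_out, inference_out


model = Model()
train_out, _ = model(torch.rand(1, 3, 100, 100), compute_inference=False)
_, inference_out = model(torch.rand(1, 3, 200, 200), compute_train=False)
both = model(torch.rand(1, 3, 100, 100))  # both branches are computed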

@ptrblck thank you for the answer! I wanted to avoid conditions because it would mean having a bunch of them in the forward pass. I just wanted to make sure that this is the best solution for such a problem.

The computation graph is constructed dynamically during the forward execution, so I don’t think there is a valid workaround such as just picking one of the outputs after the fact.
Maybe a scripted model would be able to skip “unused” parts of the model, but I’m not familiar enough with the JIT.
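A rough, untested sketch of the scripting idea, combined with the explicit conditions from above (ScriptableModel is a placeholder name; the plain bool flag and the Optional annotation are what TorchScript needs for this kind of control flow):

import torch
from typing import Optional


class ScriptableModel(torch.nn.Module):

    def __init__(self):
        super(ScriptableModel, self).__init__()
        self.conv2d_1 = torch.nn.Conv2d(3, 3, 3)
        self.conv2d_3 = torch.nn.Conv2d(3, 10, 3)
        self.conv2d_4 = torch.nn.Conv2d(10, 10, 3)

    def forward(self, input: torch.Tensor, compute_train: bool = True) -> Optional[torch.Tensor]:
        out: Optional[torch.Tensor] = None
        if compute_train:
            # Branch is only executed when the flag is set
            out = self.conv2d_4(self.conv2d_3(self.conv2d_1(input)))
        return out


scripted = torch.jit.script(ScriptableModel())
out = scripted(torch.rand(1, 3, 100, 100), compute_train=True)
skipped = scripted(torch.rand(1, 3, 100, 100), compute_train=False)  # returns None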
