Overhead from pure Python code in forward pass?

When defining forward my lizard brain always reels at code like:

def forward(self, x):
    x = nn.foo(x)
    x = nn.bar(x)
    x = nn.baz(x)
    return x


def forward(self, x):
    return nn.baz(nn.bar(nn.foo(x)))  # imagine deeper nesting and longer names

and wants to write something like:

from toolz.functoolz import pipe

def forward(self, x):
    return pipe(x, nn.foo, nn.bar, nn.baz) 

or maybe (to structure more complicated models like ResNet or DenseNet):

from toolz.functoolz import compose

class Model(nn.Module):

    def __init__(self):
      super().__init__()
      self.block1 = compose(baz, bar, foo)
      ... # more stuff here

    def forward(self, x):
      x = self.block1(x)
      ... # more stuff here
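For concreteness, here is a stdlib-only sketch of what pipe and compose do (my own minimal definitions using functools.reduce, not the actual toolz implementations):

```python
from functools import reduce

def pipe(x, *funcs):
    """pipe(x, f, g, h) == h(g(f(x))) -- left-to-right application."""
    for f in funcs:
        x = f(x)
    return x

def compose(*funcs):
    """compose(f, g, h)(x) == f(g(h(x))) -- right-to-left, like toolz.compose."""
    def composed(x):
        return reduce(lambda acc, f: f(acc), reversed(funcs), x)
    return composed

# Sanity check on plain numbers (the same mechanics apply to Tensors):
assert pipe(2, lambda v: v + 1, lambda v: v * 3) == 9
assert compose(str, abs)(-3) == "3"
```

Either way, these are just ordinary higher-order functions doing ordinary Python calls.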

Now I wonder (since PyTorch uses a dynamic computation graph) whether additional Python code (like lambdas and other pure Python code, as in modules like toolz and functools) incurs some non-negligible overhead? This maybe touches upon how the dynamic computation graph is actually built from the Python code in forward, and I’d really be interested to learn about that.

…all the while my lizard brain hopes it can indeed write more functional code without incurring a performance penalty.


The autograd engine only records the “basic” operations done on Tensors. So any logic that you add around them will not impact the autograd engine itself.
The only overhead you will have is the cost of running more Python code. Depending on the size of your net, that can be completely negligible (big CNNs) or significant (very, very small nets). But the autograd engine will not be impacted by these changes.
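To put a rough number on that Python-call overhead, here is a stdlib-only micro-benchmark (toy arithmetic functions standing in for layers; timings are machine-dependent, so no absolute numbers are claimed):

```python
import timeit

def foo(x): return x + 1.0
def bar(x): return x * 2.0
def baz(x): return x - 3.0

def pipe(x, *funcs):
    # A minimal pipe: just extra Python-level function calls, nothing more.
    for f in funcs:
        x = f(x)
    return x

# Both spellings compute the same value ...
assert pipe(1.0, foo, bar, baz) == baz(bar(foo(1.0)))

# ... and differ only by a few extra Python calls per forward pass.
nested = timeit.timeit(lambda: baz(bar(foo(1.0))), number=100_000)
piped = timeit.timeit(lambda: pipe(1.0, foo, bar, baz), number=100_000)
print(f"nested: {nested:.4f}s  piped: {piped:.4f}s")
```

For a real net, the per-call difference is fixed while the Tensor math grows with model size, which is why it only matters for very small nets.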

@albanD thank you for your answer :) Can you point me to further information / elaborate a bit more on how the PyTorch/autograd internals work (conceptually) with regard to building the computation graph from the Python code?

Basically the idea is that the autograd engine needs to know every operation that you performed so it can use their backward equivalents during the backward pass.

The basic operation is an autograd.Function, for which both forward and backward are defined. These are mostly hidden from the end user. For example, the torch.checkpoint method is actually using a single Function here.
Assuming your Tensors require grad, when you apply a Function to a Tensor, this is recorded, and the output Tensor gets a grad_fn field that says which Function was last applied to it. Similarly, the Function will look at its inputs and find out which previous Functions created them. By doing this, you obtain a directed acyclic graph of Functions.
The C++ version works exactly the same way, where each output of a Function links to that Function’s backward.

You can explore this graph this way:

import torch

a = torch.rand(10, 10, requires_grad=True)

out = a * 2
print(out.grad_fn) # MulBackward

loss = out.sum()
print(loss.grad_fn) # SumBackward
print(loss.grad_fn.next_functions) # ((MulBackward, 0),)
# Returns which Function corresponds to which input:
# Here only one input corresponds to MulBackward
# The 0 means that it was output 0 from this Function
# (multiplication had a single output)

more_things = loss / out
print(more_things.grad_fn) # DivBackward
print(more_things.grad_fn.next_functions) # ((SumBackward, 0), (MulBackward, 0))
# Returns which Function corresponds to which input:
# First input loss corresponds to SumBackward
# Second input out corresponds to MulBackward
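Building on this, you can walk next_functions recursively to print the whole backward graph. A toy sketch (the Node class below just mimics objects exposing next_functions; on a real loss you would start from loss.grad_fn and use type(fn).__name__ as the label):

```python
class Node:
    # Toy stand-in for a grad_fn: a name plus (node, output_index) tuples.
    def __init__(self, name, next_functions=()):
        self.name = name
        self.next_functions = next_functions

def walk(fn, depth=0, out=None):
    # Depth-first traversal over the (fn, output_index) pairs,
    # the same way you would chase a real grad_fn's next_functions.
    if out is None:
        out = []
    out.append("  " * depth + fn.name)
    for next_fn, _output_index in fn.next_functions:
        if next_fn is not None:  # entries can be None for inputs without grad
            walk(next_fn, depth + 1, out)
    return out

# Mimic the graph from the snippet above: loss = (a * 2).sum()
mul = Node("MulBackward")
total = Node("SumBackward", ((mul, 0),))
print("\n".join(walk(total)))
```

The same loop applied to a real loss.grad_fn prints the full DAG the engine will run backward through.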

As you can see, the autograd engine only kicks in at the Function level, so for your original question, having more convoluted Python logic will not change its behaviour (as long as you still perform the same operations on your Tensors).
