Performance with variable argument forward functions

I don’t have a good mental model of performance in PyTorch (or of how to profile with GPUs/TPUs/etc.). But basically: does the following pattern have any hidden performance “gotchas” due to the use of the vararg forward?

import torch
import torch.nn as nn

class MyNet(torch.nn.Module):
    def __init__(self, N):
        super(MyNet, self).__init__()
        self.N = N
        # Branch for a single input of size N.
        self.net1 = nn.Linear(self.N, 1, bias=True)
        # Branch for two inputs concatenated to size N + 1.
        self.net2 = nn.Linear(self.N + 1, 1, bias=True)

    def forward(self, *args):
        if len(args) == 1:
            return self.net1(args[0])
        elif len(args) == 2:
            return self.net2(torch.cat([args[0], args[1]]))
        else:
            raise ValueError("Unsupported number of arguments")

# Usage
ob = MyNet(3)
X = torch.randn(3)
x = torch.randn(1)
ob(X)
ob(x,X)

Nope, the majority of the time will be spent in the inner functions. The use of cat() is much more significant than the varargs.
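
If you want to check where the time actually goes, here is a minimal CPU-only sketch using torch.profiler (the loop count and tensor names are just for illustration; it reuses the MyNet module from above):

import torch
from torch.profiler import profile, ProfilerActivity

ob = MyNet(3)
X = torch.randn(3)
x = torch.randn(1)

# Profile a batch of forward calls on the CPU and print the most expensive ops.
with profile(activities=[ProfilerActivity.CPU], record_shapes=True) as prof:
    for _ in range(100):
        ob(x, X)

print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))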

Thanks so much. Btw, is there a better way to do the cat in this sort of scenario, i.e. to merge several branches of the computational graph into a downstream linear layer?

Well, you can’t avoid merging the memory blocks. Before gradient flow is involved, you can sometimes pre-allocate a joined buffer and write to it part by part, e.g. torch.randn(3, out=x[:3]), x[:3].normal_(), x[:3].fill_(0.1), etc.
If you instead process the “multiple branches” separately, it is usually slower than straightforward memory consolidation with cat().
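
For example, a rough sketch of the pre-allocated buffer idea, using the shapes from the example above (this only applies to building the input outside of autograd, before gradient flow starts):

import torch

# Allocate one joint buffer of size N + 1 = 4 and fill its slices in place,
# instead of creating x and X separately and then cat()-ing them.
joint = torch.empty(4)
joint[:1].normal_()              # the part that would have been x
torch.randn(3, out=joint[1:])    # the part that would have been X

out = ob.net2(joint)             # feeds the downstream linear layer directly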
