I don’t have a good mental model of performance in PyTorch (or of how to profile on GPUs/TPUs/etc.). But basically: does the following pattern have any hidden performance “gotchas” due to the use of the vararg forward?
import torch
import torch.nn as nn

class MyNet(torch.nn.Module):
    def __init__(self, N):
        super(MyNet, self).__init__()
        self.N = N
        # One head for a single size-N input, another for the size-(N+1) concatenation
        self.net1 = nn.Linear(self.N, 1, bias=True)
        self.net2 = nn.Linear(self.N + 1, 1, bias=True)

    def forward(self, *args):
        if len(args) == 1:
            return self.net1(args[0])
        elif len(args) == 2:
            return self.net2(torch.cat([args[0], args[1]]))
        else:
            raise ValueError("Unsupported number of arguments")
# Usage
ob = MyNet(3)
X = torch.randn(3)  # size-N input
x = torch.randn(1)  # extra feature
ob(X)               # dispatches to net1
ob(x, X)            # dispatches to net2 on the concatenated size-(N+1) input
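For context, this is roughly how I would try to measure it myself (a minimal sketch using torch.profiler on CPU; I'm not sure it's the right tool, and on a GPU I'd presumably also need to add ProfilerActivity.CUDA and move the tensors to a CUDA device):

from torch.profiler import profile, ProfilerActivity

# Run both call paths many times to surface any per-call overhead from the vararg dispatch
with profile(activities=[ProfilerActivity.CPU]) as prof:
    for _ in range(1000):
        ob(X)
        ob(x, X)
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))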