Best Practice for Conditionals In Forward Pass

Hi, I have a fairly simple question. What is the recommended way to implement a model that, when constructed with out_features=1, flattens its output after the forward pass? The first option below uses an nn.Identity placeholder; the second uses an if statement in the forward pass. This problem comes up a lot when I write models with many configuration options, so any advice would help guide my development. Thanks!

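```python
from torch import nn


# Option A: no condition in forward, but a possibly redundant call to nn.Identity()
class Model(nn.Module):
    def __init__(self, out_features):
        super().__init__()
        self.tail = nn.Identity()
        if out_features == 1:
            # start_dim=0 turns an (N, 1) output into (N,), matching x.flatten() in Option B;
            # the default start_dim=1 would leave (N, 1) unchanged
            self.tail = nn.Flatten(start_dim=0)

    def forward(self, x):
        # x = ...  (model body elided)
        return self.tail(x)


# Option B: condition in forward
class Model(nn.Module):
    def __init__(self, out_features):
        super().__init__()
        self.out_features = out_features

    def forward(self, x):
        # x = ...  (model body elided)
        if self.out_features == 1:
            return x.flatten()  # flattens all dims, so (N, 1) -> (N,)
        return x
```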
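To check that the two options behave the same, here is a minimal runnable sketch. The `nn.Linear` body and the `in_features` parameter are hypothetical stand-ins for the elided model body. Note that `nn.Flatten()` defaults to `start_dim=1`, which leaves an `(N, 1)` tensor unchanged, so `start_dim=0` is used here to match `x.flatten()`.

```python
import torch
from torch import nn


class ModelA(nn.Module):
    """Option A: choose the tail module once, in __init__."""

    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        # Hypothetical body standing in for the elided `x = ...`.
        self.body = nn.Linear(in_features, out_features)
        # start_dim=0 turns (N, 1) into (N,); the default start_dim=1 would not.
        self.tail = nn.Flatten(start_dim=0) if out_features == 1 else nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.tail(self.body(x))


class ModelB(nn.Module):
    """Option B: branch on a stored setting inside forward."""

    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        self.body = nn.Linear(in_features, out_features)  # hypothetical body
        self.out_features = out_features

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.body(x)
        if self.out_features == 1:
            return x.flatten()
        return x


x = torch.randn(8, 4)  # batch of 8 samples with 4 features each
for out_features in (1, 3):
    a = ModelA(4, out_features)
    b = ModelB(4, out_features)
    b.load_state_dict(a.state_dict())  # copy weights so the outputs are comparable
    print(out_features, tuple(a(x).shape), torch.allclose(a(x), b(x)))
# prints "1 (8,) True" then "3 (8, 3) True"
```

In eager mode the extra `nn.Identity` call in Option A is just one more Python call that returns its input, so the choice is mostly about readability: Option A keeps `forward` branch-free and makes the configured tail visible when you print the model.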

I wouldn’t worry about this too much performance-wise, but note that TorchScript (torch.jit.script) lets you remove such runtime checks entirely by marking the flag as a compile-time constant:


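```python
from torch import jit, nn


class Model(nn.Module):
    flag: jit.Final[bool]  # Final attributes are compile-time constants in TorchScript

    def __init__(self, out_features):
        super().__init__()
        # ... (rest of __init__ elided)
        self.flag = out_features == 1  # flatten only when out_features == 1, as in the question
        self.tail = nn.Flatten(start_dim=0) if self.flag else None

    def forward(self, x):
        # ... (model body elided)
        if self.flag:  # constant condition, so TorchScript resolves it at compile time
            x = self.tail(x)
        return x


m = jit.script(Model(1))
```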

This should remove the `if` block from the compiled code entirely. Even simpler, a `self.tail is not None` check would also be resolved statically.
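A minimal sketch of the `is not None` variant, assuming a hypothetical `nn.Linear` body in place of the elided code. When `tail` is `None`, TorchScript can prove the check false at compile time and skips the branch, so the module scripts even though `None` is not callable; printing `m.code` lets you inspect the compiled `forward` for each configuration.

```python
import torch
from torch import nn


class Model(nn.Module):
    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        self.body = nn.Linear(in_features, out_features)  # hypothetical body
        # Only create the tail when it is needed; otherwise store None.
        self.tail = nn.Flatten(start_dim=0) if out_features == 1 else None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.body(x)
        # Statically resolved by TorchScript: always true for a module,
        # always false for None, so no runtime check remains.
        if self.tail is not None:
            x = self.tail(x)
        return x


x = torch.randn(8, 4)  # batch of 8 samples with 4 features each
for out_features in (1, 3):
    m = torch.jit.script(Model(4, out_features))
    print(out_features, tuple(m(x).shape))
    print(m.code)  # inspect the compiled forward for this configuration
```

Compared with the `jit.Final` flag, this keeps a single source of truth (the `tail` attribute itself), so the flag and the module cannot get out of sync.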

cool, I’ll try it out