Hi, I have a fairly simple question. What is the recommended way to implement a model that flattens its output after the forward pass when it is constructed with `out_features=1`? The first option below uses an `Identity` placeholder; the second uses an `if` statement in the forward pass. This problem comes up a lot when I build models with many configurable options, so any advice would help guide my development. Thanks!
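```python
import torch
from torch.nn import Module

# Option A: No conditions in the forward pass, but a redundant call to nn.Identity()
class Model(Module):
    def __init__(self, out_features):
        super().__init__()
        # Identity is a no-op; swap in Flatten only when out_features == 1.
        # start_dim=0 turns an (N, 1) output into (N,), matching x.flatten() in Option B
        # (the default Flatten() starts at dim 1 and would leave (N, 1) unchanged).
        self.tail = torch.nn.Identity()
        if out_features == 1:
            self.tail = torch.nn.Flatten(start_dim=0)

    def forward(self, x):
        x = ...  # model body elided; produces shape (N, out_features)
        return self.tail(x)
```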
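```python
from torch.nn import Module

# Option B: Condition in the forward pass
class Model(Module):
    def __init__(self, out_features):
        super().__init__()
        self.out_features = out_features

    def forward(self, x):
        x = ...  # model body elided; produces shape (N, out_features)
        if self.out_features == 1:
            return x.flatten()  # collapse (N, 1) into (N,)
        return x
```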