One option is to define each branch as an nn.Sequential
and put them in a list. Then, in forward, use torch.cat
to concatenate their outputs along the channel axis (dim 1). For example:
class InceptionBlock(nn.Module):
    """Run several convolutional branches on one input and concatenate them.

    Every branch receives the same input tensor. Each branch ends in a
    convolution with ``num_out`` output channels and preserves the spatial
    size (1x1 kernels, or 3x3 kernels with ``padding=1``), so the branch
    outputs can be concatenated along the channel dimension.

    Args:
        num_in: Number of channels in the input tensor.
        num_out: Number of channels produced by each individual branch.
    """

    def __init__(self, num_in: int, num_out: int) -> None:
        super(InceptionBlock, self).__init__()
        self.branches = [
            # Branch 0: pointwise (1x1) convolution.
            nn.Sequential(
                nn.Conv2d(num_in, num_out, kernel_size=1),
                nn.ReLU(),
            ),
            # Branch 1: 1x1 convolution followed by a 3x3 convolution;
            # padding=1 keeps the spatial size unchanged.
            nn.Sequential(
                nn.Conv2d(num_in, num_out, kernel_size=1),
                nn.ReLU(),
                nn.Conv2d(num_out, num_out, kernel_size=3, padding=1),
                nn.ReLU(),
            ),
            # Add further branches here. Every entry must be an nn.Module:
            # a literal `...` placeholder would be rejected by add_module.
        ]
        # A plain Python list does not register its contents as submodules,
        # so each branch is registered explicitly. Without this the branch
        # parameters would be missing from parameters() and state_dict().
        for i, branch in enumerate(self.branches):
            self.add_module(str(i), branch)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply every branch to ``x`` and concatenate along dim 1 (channels).

        Args:
            x: Input tensor with ``num_in`` channels in dimension 1.

        Returns:
            Tensor whose channel dimension holds the outputs of all
            branches, in the order the branches are listed.
        """
        # Concatenate branch results along channels
        return torch.cat([b(x) for b in self.branches], 1)
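EDIT: You need to call add_module to register the branches. Modules stored
in a plain Python list are not registered as submodules, so without this their
parameters would be missing from parameters() and state_dict(). Alternatively,
store the branches in an nn.ModuleList, which registers them automatically.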