How can I build a network structure equivalent to the old nn.Concat or nn.Parallel?

One option is to define each branch as an nn.Sequential and keep the branches in a list. Then, in forward, apply every branch to the same input and use torch.cat to concatenate their outputs along the desired dimension. For example,

import torch
import torch.nn as nn

class InceptionBlock(nn.Module):
    def __init__(self, num_in, num_out):
        super(InceptionBlock, self).__init__()
        self.branches = [
            nn.Sequential(
                nn.Conv2d(num_in, num_out, kernel_size=1),
                nn.ReLU()),
            nn.Sequential(
                nn.Conv2d(num_in, num_out, kernel_size=1),
                nn.ReLU(),
                nn.Conv2d(num_out, num_out, kernel_size=3, padding=1),
                nn.ReLU()),
            ...
        ]
        # Register each branch as a submodule; a plain Python list is not tracked by nn.Module
        for i, branch in enumerate(self.branches):
            self.add_module(str(i), branch)

    def forward(self, x):
        # Concatenate branch results along channels
        return torch.cat([b(x) for b in self.branches], 1)

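EDIT: You need to call add_module to register the branches. Modules stored in a plain Python list are not registered, so their parameters would not appear in parameters() and would not be moved by calls such as .cuda().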
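
As an alternative to the add_module loop, the branches could probably be wrapped in nn.ModuleList, which registers its contents automatically. The following is only a minimal sketch under that assumption: the class name ConcatBranches, its dim argument, and the channel sizes are made up for illustration and are not part of the answer above.

```python
import torch
import torch.nn as nn


class ConcatBranches(nn.Module):
    """Hypothetical helper: run every branch on the same input and
    concatenate the outputs along `dim`."""

    def __init__(self, branches, dim=1):
        super().__init__()
        # nn.ModuleList registers each branch, so no add_module loop is needed
        self.branches = nn.ModuleList(branches)
        self.dim = dim

    def forward(self, x):
        return torch.cat([branch(x) for branch in self.branches], self.dim)


# Two illustrative branches; the channel counts are arbitrary
block = ConcatBranches([
    nn.Sequential(nn.Conv2d(3, 8, kernel_size=1), nn.ReLU()),
    nn.Sequential(nn.Conv2d(3, 8, kernel_size=3, padding=1), nn.ReLU()),
])

x = torch.randn(2, 3, 16, 16)
y = block(x)
print(y.shape)                        # torch.Size([2, 16, 16, 16])
print(len(list(block.parameters())))  # 4: a weight and a bias per conv
```

Every branch has to produce the same size in all dimensions except the one being concatenated, which is why the 3x3 convolution uses padding=1 here.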
