How could I build a network structure equivalent to the old nn.Concat or nn.Parallel?

I mean combining branches of sub-networks together. This was usually done with nn.Concat in Lua Torch. I searched but could only find torch.nn.Sequential.


Maybe define each branch as an nn.Sequential and put them in a list. Then, in forward, use torch.cat to concatenate their outputs along some dimension. For example:

import torch
import torch.nn as nn


class InceptionBlock(nn.Module):
    def __init__(self, num_in, num_out):
        super(InceptionBlock, self).__init__()
        self.branches = [
            nn.Sequential(
                nn.Conv2d(num_in, num_out, kernel_size=1),
                nn.ReLU()),
            nn.Sequential(
                nn.Conv2d(num_in, num_out, kernel_size=1),
                nn.ReLU(),
                nn.Conv2d(num_out, num_out, kernel_size=3, padding=1),
                nn.ReLU()),
            # ... more branches as needed
        ]
        # **EDIT**: need to call add_module so each branch is registered
        # as a submodule (a plain Python list is invisible to nn.Module)
        for i, branch in enumerate(self.branches):
            self.add_module(str(i), branch)

    def forward(self, x):
        # Concatenate branch outputs along the channel dimension
        return torch.cat([b(x) for b in self.branches], dim=1)

EDIT: Need to call add_module to register the branches
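For reference, a quick usage sketch; the channel sizes and input shape are only illustrative, and it assumes the block keeps just the two branches shown above:

import torch

block = InceptionBlock(num_in=64, num_out=32)
x = torch.randn(8, 64, 28, 28)  # (batch, channels, height, width)
out = block(x)
# Each of the two branches contributes 32 channels, so:
print(out.shape)  # torch.Size([8, 64, 28, 28])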


I think you might find this thread helpful.


I found that the previous solution has problems with nn.DataParallel, which raises RuntimeError: tensors are on different GPUs, even in the bleeding-edge version.

It might be related to https://github.com/pytorch/pytorch/issues/689. A temporary solution is to create a member variable for each branch (e.g., self.b1 = nn.Sequential(...)) instead of grouping them into a list.
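A rough sketch of that workaround, assuming just two branches; the attribute names self.b1 and self.b2 are only illustrative, and the branch contents are copied from the example above:

import torch
import torch.nn as nn


class InceptionBlock(nn.Module):
    def __init__(self, num_in, num_out):
        super(InceptionBlock, self).__init__()
        # Each branch is a direct attribute, so it is registered automatically
        self.b1 = nn.Sequential(
            nn.Conv2d(num_in, num_out, kernel_size=1),
            nn.ReLU())
        self.b2 = nn.Sequential(
            nn.Conv2d(num_in, num_out, kernel_size=1),
            nn.ReLU(),
            nn.Conv2d(num_out, num_out, kernel_size=3, padding=1),
            nn.ReLU())

    def forward(self, x):
        # Concatenate branch outputs along the channel dimension
        return torch.cat([self.b1(x), self.b2(x)], dim=1)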

@Cysu, or you could use self.branches = nn.ModuleList([...]). This registers the branches as submodules, so their parameters stay in sync even when used with data parallel.
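A minimal sketch of that version, reusing the branches from the example above:

import torch
import torch.nn as nn


class InceptionBlock(nn.Module):
    def __init__(self, num_in, num_out):
        super(InceptionBlock, self).__init__()
        # nn.ModuleList registers every branch as a submodule, so parameters
        # are seen by .parameters(), .cuda(), and DataParallel replication
        self.branches = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(num_in, num_out, kernel_size=1),
                nn.ReLU()),
            nn.Sequential(
                nn.Conv2d(num_in, num_out, kernel_size=1),
                nn.ReLU(),
                nn.Conv2d(num_out, num_out, kernel_size=3, padding=1),
                nn.ReLU()),
        ])

    def forward(self, x):
        # Concatenate branch outputs along the channel dimension
        return torch.cat([b(x) for b in self.branches], dim=1)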


Oh, that’s perfect! I wasn’t aware of this before. Thanks very much.