I mean to combine branches of sub-networks together, which was usually done with `nn.Concat` in Lua Torch. I searched but could only find `torch.nn.Sequential`.
Maybe define each branch as an `nn.Sequential` and put them in a list. Then, in `forward`, use `torch.cat` to concatenate their outputs along some axis. For example:
```python
import torch
import torch.nn as nn

class InceptionBlock(nn.Module):
    def __init__(self, num_in, num_out):
        super(InceptionBlock, self).__init__()
        self.branches = [
            nn.Sequential(
                nn.Conv2d(num_in, num_out, kernel_size=1),
                nn.ReLU()),
            nn.Sequential(
                nn.Conv2d(num_in, num_out, kernel_size=1),
                nn.ReLU(),
                nn.Conv2d(num_out, num_out, kernel_size=3, padding=1),
                nn.ReLU()),
            # ... more branches here
        ]
        # EDIT: need to call add_module so each branch is registered
        # as a submodule and its parameters are tracked
        for i, branch in enumerate(self.branches):
            self.add_module(str(i), branch)

    def forward(self, x):
        # Concatenate the branch outputs along the channel dimension (dim 1)
        return torch.cat([b(x) for b in self.branches], 1)
```
EDIT: Need to call `add_module` to register the branches as submodules.
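As a quick sanity check, here is a hypothetical usage, assuming only the two branches shown above; the shapes are just illustrative:

```python
block = InceptionBlock(num_in=3, num_out=16)
x = torch.randn(2, 3, 32, 32)   # batch of 2, 3 channels, 32x32
out = block(x)
print(out.shape)  # torch.Size([2, 32, 32, 32]): 16 channels from each of the 2 branches
```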
I found that the previous solution has problems with `DataParallel`, which complains `RuntimeError: tensors are on different GPUs`, even in the bleeding-edge version. It might be related to https://github.com/pytorch/pytorch/issues/689. A temporary workaround is to create a member variable for each branch (e.g., `self.b1 = nn.Sequential(...)`) instead of grouping them into a list.
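A minimal sketch of that workaround, assuming the same two branches as above:

```python
import torch
import torch.nn as nn

class InceptionBlock(nn.Module):
    def __init__(self, num_in, num_out):
        super(InceptionBlock, self).__init__()
        # Each branch is a direct attribute, so it is registered automatically
        self.b1 = nn.Sequential(
            nn.Conv2d(num_in, num_out, kernel_size=1),
            nn.ReLU())
        self.b2 = nn.Sequential(
            nn.Conv2d(num_in, num_out, kernel_size=1),
            nn.ReLU(),
            nn.Conv2d(num_out, num_out, kernel_size=3, padding=1),
            nn.ReLU())

    def forward(self, x):
        return torch.cat([self.b1(x), self.b2(x)], 1)
```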
@Cysu or you could use `self.branches = nn.ModuleList([...])`. This will ensure that they remain in sync even when used with data parallel.
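A sketch of that version, keeping the same two branches as in the earlier example:

```python
import torch
import torch.nn as nn

class InceptionBlock(nn.Module):
    def __init__(self, num_in, num_out):
        super(InceptionBlock, self).__init__()
        # nn.ModuleList registers every contained module, so no add_module
        # calls are needed, and DataParallel replicates the branches correctly
        self.branches = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(num_in, num_out, kernel_size=1),
                nn.ReLU()),
            nn.Sequential(
                nn.Conv2d(num_in, num_out, kernel_size=1),
                nn.ReLU(),
                nn.Conv2d(num_out, num_out, kernel_size=3, padding=1),
                nn.ReLU()),
        ])

    def forward(self, x):
        return torch.cat([b(x) for b in self.branches], 1)
```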
Oh, that's perfect! I wasn't aware of this before. Thanks very much.