Are there utilities like ConcatTable and ParallelTable (from Lua Torch) in PyTorch?
No, but you can easily achieve the same thing with autograd, because the forward pass is ordinary Python code.
For example, for ConcatTable:
import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()  # required before assigning submodules
        self.submodule1 = MySubmodule1()
        self.submodule2 = MySubmodule2()

    def forward(self, x):
        # this is equivalent to a ConcatTable containing both submodules,
        # but is much more readable
        concat = [self.submodule1(x), self.submodule2(x)]
        # do some other processing on concat...
        return concat
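As a rough, self-contained sketch of the same idea with concrete layers, both Lua Torch containers reduce to a few lines of `forward`: ConcatTable feeds the same input to every branch, while ParallelTable feeds the i-th input to the i-th branch. The class names `ConcatLike` and `ParallelLike` and the `nn.Linear` sizes are hypothetical, chosen only for illustration.

```python
import torch
import torch.nn as nn


class ConcatLike(nn.Module):
    """Hypothetical stand-in for ConcatTable: same input goes to every branch."""

    def __init__(self, *branches):
        super().__init__()
        # nn.ModuleList registers each branch so its parameters reach the optimizer
        self.branches = nn.ModuleList(branches)

    def forward(self, x):
        return [branch(x) for branch in self.branches]


class ParallelLike(nn.Module):
    """Hypothetical stand-in for ParallelTable: i-th input goes to i-th branch."""

    def __init__(self, *branches):
        super().__init__()
        self.branches = nn.ModuleList(branches)

    def forward(self, inputs):
        # assumes len(inputs) == len(self.branches)
        return [branch(inp) for branch, inp in zip(self.branches, inputs)]


concat = ConcatLike(nn.Linear(4, 3), nn.Linear(4, 5))
outs = concat(torch.randn(2, 4))
print([o.shape for o in outs])  # [torch.Size([2, 3]), torch.Size([2, 5])]

# merge the branch outputs into one tensor along the feature dimension
merged = torch.cat(outs, dim=1)
print(merged.shape)  # torch.Size([2, 8])

parallel = ParallelLike(nn.Linear(4, 3), nn.Linear(6, 5))
outs = parallel([torch.randn(2, 4), torch.randn(2, 6)])
print([o.shape for o in outs])  # [torch.Size([2, 3]), torch.Size([2, 5])]
```

Autograd tracks every operation in `forward`, so gradients flow through each branch without any special table container.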
Hi, what if the number of submodules is variable and chosen by the user?
I tried the following:
import torch
import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self, submod_num, submod_var):
        super().__init__()
        self.submod_pool = []
        for i in range(submod_num):
            self.submod_pool += [MySubmodule(submod_var[i])]

    def forward(self, X):
        outPool = []
        for submod in self.submod_pool:
            outPool += [submod(X)]
        return torch.cat(outPool, dim=1)  # dim assumed: concatenate along features
But I’m not sure it works: the code runs without errors, but the training error is not what I expected. When I used tensorboardX to draw the graph, it returned an error clearly indicating that the graph is disconnected.
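A likely cause, and a documented PyTorch behavior: submodules stored in a plain Python list are not registered with the parent `nn.Module`, so `model.parameters()`, `state_dict()` and `.to(device)` all skip them, and an optimizer built from `model.parameters()` never updates their weights. A minimal check, with `nn.Linear(4, 2)` standing in for the hypothetical `MySubmodule`:

```python
import torch.nn as nn


class ListModel(nn.Module):
    def __init__(self, n):
        super().__init__()
        # plain Python list: nn.Module does not see these layers
        self.submod_pool = [nn.Linear(4, 2) for _ in range(n)]


class ModuleListModel(nn.Module):
    def __init__(self, n):
        super().__init__()
        # nn.ModuleList registers every layer as a submodule
        self.submod_pool = nn.ModuleList([nn.Linear(4, 2) for _ in range(n)])


def count_params(model):
    # total number of scalar parameters the optimizer would receive
    return sum(p.numel() for p in model.parameters())


print(count_params(ListModel(3)))        # 0
print(count_params(ModuleListModel(3)))  # 30 (3 layers x (8 weights + 2 biases))
```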
@Amir_Ghodrati sorry for the late reply, I didn’t notice the mention.
I solved it like this:
model = nn.Sequential(*layer_list)
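Note that `nn.Sequential(*layer_list)` registers the layers but also chains them, feeding each layer's output into the next. If the goal is the one from the question, applying every submodule to the same input and concatenating the results, `nn.ModuleList` keeps that parallel structure. A sketch of the earlier `MyModel` along those lines, assuming `nn.Linear` in place of the hypothetical `MySubmodule`, `submod_var` read as per-submodule output sizes, and concatenation along dimension 1 (all assumptions, not from the thread):

```python
import torch
import torch.nn as nn


class MyModel(nn.Module):
    def __init__(self, submod_num, submod_var, in_features=4, dim=1):
        super().__init__()
        self.dim = dim  # dimension to concatenate along (assumed: features)
        # nn.ModuleList registers a user-chosen number of submodules
        self.submod_pool = nn.ModuleList(
            [nn.Linear(in_features, submod_var[i]) for i in range(submod_num)]
        )

    def forward(self, x):
        # apply every submodule to the same input, then concatenate the outputs
        return torch.cat([submod(x) for submod in self.submod_pool], dim=self.dim)


model = MyModel(submod_num=3, submod_var=[2, 3, 5])
print(model(torch.randn(8, 4)).shape)  # torch.Size([8, 10])

# the nn.Sequential variant from the reply, for comparison: layers run one after another
layer_list = [nn.Linear(4, 6), nn.ReLU(), nn.Linear(6, 2)]
seq = nn.Sequential(*layer_list)
print(seq(torch.randn(8, 4)).shape)  # torch.Size([8, 2])
```

Both containers make the parameters visible to the optimizer; the choice depends on whether the submodules should run in parallel on one input or in sequence.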