I would like to custom nn.Module including max unpooling module.
For example, if I declare MyModule1 and 2 which outputs the pooling index beside output signal and takes the pooling index (pool_idx), respectively, as below.
class MyModule1(nn.Module):
def __init__(self):
super(MyModule1, self).__init__()
self.conv = nn.Conv2d(64, 3, 4)
self.pool = nn.MaxPool2d(3, 3, return_indices = True)
def foward(self, x):
x = self.convl(x)
x, pool_idx = self.pool(x)
return x, pool_idx
class MyModule2(nn.Module):
def __init__(self, pool_idx):
super(MyModule, self).__init__()
self.pool_idx = pool_idx
self.conv = nn.ConvTransposed2d(3, 64, 4)
self.unpool = nn.MaxUnpool2d(3,3)
def forward(self, x):
x = self.unpool(x, self.pool_idx)
x = self.conv(x)
return x
But, I do not think that the 2nd module accepts the pooling index as a input.
Also, I do not know how to make a bigger module including these two custom modules.
For example,
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.m1 = MyModule1()
self.m2 = MyModule2()
def forward(self, x):
x, pool_idx1, pool_idx2 = self.m1(x)
x = self.m2(x, pool_idx1, pool_idx2)
return x
But, it won’t work.
I would greatly appreciate if you could help me to solve this problem.