How can I make a custom nn.Module that uses max-pooling indices?

I would like to write a custom nn.Module that includes a max-unpooling layer.

For example, suppose I declare two modules, as below: MyModule1, which outputs the pooling indices alongside the output signal, and MyModule2, which takes those pooling indices (pool_idx).

class MyModule1(nn.Module):
    """Convolution followed by max pooling that also returns the pooling indices.

    The indices are returned to the caller (not stored on the module) so they
    can be handed to a matching ``nn.MaxUnpool2d`` in a later stage.
    """

    def __init__(self):
        super(MyModule1, self).__init__()
        # 64 input channels -> 3 output channels, 4x4 kernel.
        self.conv = nn.Conv2d(64, 3, 4)
        # return_indices=True makes the pool return an (output, indices) pair.
        self.pool = nn.MaxPool2d(3, 3, return_indices=True)

    def forward(self, x):
        """Apply the convolution, then max pooling.

        Returns:
            A tuple ``(x, pool_idx)``: the pooled feature map and the argmax
            indices produced by ``self.pool``.
        """
        x = self.conv(x)
        x, pool_idx = self.pool(x)
        return x, pool_idx

class MyModule2(nn.Module):
    """Max unpooling followed by a transposed convolution (indices fixed at init).

    NOTE(review): the pooling indices are captured once in the constructor, so
    every ``forward`` call unpools with that same ``pool_idx``. Indices computed
    for a new batch cannot be supplied. Accepting ``pool_idx`` as a ``forward``
    argument avoids this.
    """

    def __init__(self, pool_idx):
        super(MyModule2, self).__init__()
        # Assigned as a plain tensor attribute (not a registered buffer), so it
        # is not part of state_dict() and is not moved by .to(device).
        self.pool_idx = pool_idx
        # 3 input channels -> 64 output channels, 4x4 kernel.
        self.conv = nn.ConvTranspose2d(3, 64, 4)
        # Kernel/stride 3; should match the max-pool that produced pool_idx.
        self.unpool = nn.MaxUnpool2d(3, 3)

    def forward(self, x):
        """Unpool ``x`` with the stored indices, then apply the transposed conv."""
        x = self.unpool(x, self.pool_idx)
        x = self.conv(x)
        return x

However, I do not think the second module accepts the pooling indices as an input.
I also do not know how to build a bigger module that contains these two custom modules.

For example,

class Net(nn.Module):
    """Attempted composition of MyModule1 and MyModule2 from the question.

    NOTE(review): this is the failing attempt the question describes and is
    kept as-is for context. As written it cannot run against the MyModule1 /
    MyModule2 definitions above:
      * MyModule2.__init__ requires a ``pool_idx`` argument, but it is
        constructed here with none.
      * MyModule1.forward returns two values ``(x, pool_idx)``, but three are
        unpacked here.
      * MyModule2.forward accepts only ``x``, but three arguments are passed.
    """
    def __init__(self):
        super(Net, self).__init__()
        self.m1 = MyModule1()
        # Fails: MyModule2.__init__ above expects pool_idx.
        self.m2 = MyModule2()

    def forward(self, x):
        # Fails: MyModule1 returns (x, pool_idx), i.e. two values, not three.
        x, pool_idx1, pool_idx2 = self.m1(x)
        x = self.m2(x, pool_idx1, pool_idx2)
        return x

However, it doesn't work.
I would greatly appreciate it if you could help me solve this problem.

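You shouldn’t store the pooling indices in the constructor; instead, pass them as an argument to forward(), so each call can use the indices produced for its own input. This should work: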

class MyModule1(nn.Module):
    """Convolution followed by max pooling that also returns the pooling indices.

    The indices are returned to the caller (not stored on the module) so they
    can be handed to a matching ``nn.MaxUnpool2d`` in a later stage.
    """

    def __init__(self):
        super(MyModule1, self).__init__()
        # 64 input channels -> 3 output channels, 4x4 kernel.
        self.conv = nn.Conv2d(64, 3, 4)
        # return_indices=True makes the pool return an (output, indices) pair.
        self.pool = nn.MaxPool2d(3, 3, return_indices=True)

    def forward(self, x):
        """Apply the convolution, then max pooling.

        Returns:
            A tuple ``(x, pool_idx)``: the pooled feature map and the argmax
            indices produced by ``self.pool``.
        """
        x = self.conv(x)
        x, pool_idx = self.pool(x)
        return x, pool_idx

class MyModule2(nn.Module):
    """Max unpooling followed by a transposed convolution.

    The pooling indices are supplied on every ``forward`` call rather than at
    construction, so each input is unpooled with the indices produced for it.
    """

    def __init__(self):
        super(MyModule2, self).__init__()
        # 3 input channels -> 64 output channels, 4x4 kernel.
        self.conv = nn.ConvTranspose2d(3, 64, 4)
        # Kernel/stride 3; should match the max-pool that produced pool_idx.
        self.unpool = nn.MaxUnpool2d(3, 3)

    def forward(self, x, pool_idx, output_size=None):
        """Unpool ``x`` using ``pool_idx``, then apply the transposed conv.

        Args:
            x: Pooled feature map.
            pool_idx: Indices returned by the matching max-pool layer.
            output_size: Optional size of the tensor before pooling. Pass it
                when that size is not reproduced by the default unpool size
                ``(in - 1) * stride + kernel_size``. Otherwise the unpooled
                tensor is too small and out-of-range indices raise an error.

        Returns:
            The output of ``self.conv`` applied to the unpooled tensor.
        """
        x = self.unpool(x, pool_idx, output_size=output_size)
        x = self.conv(x)
        return x

class Net(nn.Module):
    """Chain MyModule1 into MyModule2, threading the max-pool indices between them."""

    def __init__(self):
        super().__init__()
        # Registration order (m1, then m2) is kept so parameter ordering is unchanged.
        self.m1 = MyModule1()
        self.m2 = MyModule2()

    def forward(self, x):
        """Pool with m1, then unpool/decode with m2 using m1's indices."""
        pooled, indices = self.m1(x)
        return self.m2(pooled, indices)

Thanks!!!
It works now.