Using nn.functional.interpolate() inside nn.Sequential()

nn.Sequential only accepts nn.Module instances, not plain functions. As a workaround, you can wrap the interpolate function in an nn.Module:

class Interpolate(nn.Module):
    """Wrap ``nn.functional.interpolate`` in a module so it can sit inside
    ``nn.Sequential``.

    Args:
        size: target output size, forwarded unchanged to ``interpolate``.
        mode: interpolation algorithm forwarded to ``interpolate``
            (e.g. ``'nearest'``, ``'bilinear'``).
        align_corners: forwarded only when ``mode`` is one of
            ``linear``/``bilinear``/``bicubic``/``trilinear``. Other modes
            reject any non-None value, so it is ignored for them.
            Defaults to ``False``, matching the previous hard-coded value.
    """

    # Modes for which interpolate accepts a non-None align_corners.
    _ALIGN_CORNERS_MODES = ('linear', 'bilinear', 'bicubic', 'trilinear')

    def __init__(self, size, mode, align_corners=False):
        super(Interpolate, self).__init__()
        self.interp = nn.functional.interpolate
        self.size = size
        self.mode = mode
        self.align_corners = align_corners

    def forward(self, x):
        """Resize ``x`` to ``self.size`` using ``self.mode``."""
        # Passing align_corners=False with e.g. mode='nearest' raises
        # ValueError in interpolate, so pass None for those modes.
        if self.mode in self._ALIGN_CORNERS_MODES:
            align_corners = self.align_corners
        else:
            align_corners = None
        return self.interp(x, size=self.size, mode=self.mode,
                           align_corners=align_corners)

    def extra_repr(self):
        # Shown inside the module's repr, e.g. when printing a Sequential.
        return 'size={}, mode={!r}, align_corners={}'.format(
            self.size, self.mode, self.align_corners)
    

class MeeDecoder(torch.nn.Module):
    """Decode a 40-channel feature map into a 1-channel map of size (1025, 15).

    Pipeline: two 1x1 convolutions (40 -> 30 -> 20 channels), bilinear resize
    to (205, 5), a 3x3 convolution (20 -> 12), bilinear resize to (1025, 15),
    and a final 1x1 convolution (12 -> 1). A ReLU follows every convolution,
    so the output is non-negative.

    The input presumably has shape (N, 40, H, W), since bilinear resizing
    needs 4-D input. TODO: confirm against callers.
    """

    def __init__(self):
        super().__init__()

        def conv(in_ch, out_ch, k):
            # Square kernel, stride 1, with bias, default (zero) padding.
            return torch.nn.Conv2d(in_channels=in_ch, out_channels=out_ch,
                                   kernel_size=(k, k), stride=1, bias=True)

        # Attributes are assigned in this order so parameter registration
        # (and therefore state_dict keys and random init) stay stable.
        self.de_layer1 = conv(40, 30, 1)
        self.de_layer2 = conv(30, 20, 1)
        self.de_layer3 = Interpolate(size=(205, 5), mode='bilinear')
        # No padding: this 3x3 conv shrinks the (205, 5) map to (203, 3);
        # the next resize restores a fixed size regardless.
        self.de_layer4 = conv(20, 12, 3)
        self.de_layer5 = Interpolate(size=(1025, 15), mode='bilinear')
        self.de_layer6 = conv(12, 1, 1)
        self.de_layerReLU = torch.nn.ReLU()

        # One ReLU instance is reused at several positions. It holds no
        # parameters or state, so sharing it is harmless.
        act = self.de_layerReLU
        stages = [
            self.de_layer1, act,
            self.de_layer2, act,
            self.de_layer3,
            self.de_layer4, act,
            self.de_layer5,
            self.de_layer6, act,
        ]
        self.decoder = torch.nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the decoder pipeline and return the result."""
        decoded = self.decoder(x)
        return decoded
25 Likes