Using nn.functional.interpolate() inside nn.Sequential()

nn.Upsample() is deprecated in PyTorch versions > 0.4.0 in favor of nn.functional.interpolate().
I’m not able to use interpolate() inside nn.Sequential():

Below is my network:

class MeeDecoder(torch.nn.Module):
    """Convolutional decoder from 40-channel feature maps to a 1-channel map.

    ``torch.nn.functional.interpolate`` is a plain function that must be
    called *with* an input tensor. Calling it here without one raises a
    ``TypeError``, and a function is not an ``nn.Module``, so it cannot be an
    element of ``nn.Sequential``. The ``torch.nn.Upsample`` module wraps
    ``interpolate`` and can be used as a layer instead.

    The output spatial size is fixed to (1025, 15) by ``de_layer5``.
    """

    def __init__(self):
        super(MeeDecoder, self).__init__()

        self.de_layer1 = torch.nn.Conv2d(in_channels=40, out_channels=30, kernel_size=(1,1), stride=1, bias=True)
        self.de_layer2 = torch.nn.Conv2d(in_channels=30, out_channels=20, kernel_size=(1,1), stride=1, bias=True)
        # Module form of interpolate, so it can live inside nn.Sequential.
        # align_corners=False is the bilinear default; stated explicitly.
        self.de_layer3 = torch.nn.Upsample(size=(205, 5), mode='bilinear', align_corners=False)
        # 3x3 kernel without padding: spatial size shrinks by 2 per dim,
        # (205, 5) -> (203, 3); de_layer5 resizes to a fixed size afterwards.
        self.de_layer4 = torch.nn.Conv2d(in_channels=20, out_channels=12, kernel_size=(3,3), stride=1, bias=True)
        self.de_layer5 = torch.nn.Upsample(size=(1025, 15), mode='bilinear', align_corners=False)
        self.de_layer6 = torch.nn.Conv2d(in_channels=12, out_channels=1, kernel_size=(1,1), stride=1, bias=True)
        # ReLU has no parameters, so one instance is reused at every position.
        self.de_layerReLU = torch.nn.ReLU()

        self.decoder = torch.nn.Sequential(
            self.de_layer1,
            self.de_layerReLU,
            self.de_layer2,
            self.de_layerReLU,
            self.de_layer3,
            self.de_layer4,
            self.de_layerReLU,
            self.de_layer5,
            self.de_layer6,
            self.de_layerReLU
        )

    def forward(self, x):
        """Decode ``x`` to shape (N, 1, 1025, 15).

        ``x`` must be 4-D (bilinear upsampling requires it) with 40 channels
        (``de_layer1``); presumably (N, 40, H, W) -- confirm against caller.
        """
        return self.decoder(x)

Any help would be appreciated.
Thanks

2 Likes

As a workaround, you could create an nn.Module that wraps the interpolate function:

class Interpolate(nn.Module):
    """Module wrapper around ``nn.functional.interpolate``.

    Wrapping the function in a module lets it be used as a layer, e.g. inside
    ``nn.Sequential``.

    Args:
        size: target output spatial size passed to ``interpolate``. May be
            ``None`` when ``scale_factor`` is given instead.
        mode: interpolation algorithm (``'nearest'``, ``'bilinear'``, ...).
        align_corners: passed to ``interpolate``. If ``None`` (default), it is
            set to ``False`` for the linear-family modes, which preserves the
            previous hard-coded behavior. For other modes it stays ``None``,
            because ``interpolate`` rejects ``align_corners`` for modes such as
            ``'nearest'`` and ``'area'``.
        scale_factor: optional spatial multiplier, used instead of ``size``.
    """

    # Modes for which interpolate accepts an align_corners value.
    _ALIGNABLE_MODES = ("linear", "bilinear", "bicubic", "trilinear")

    def __init__(self, size, mode, align_corners=None, scale_factor=None):
        super(Interpolate, self).__init__()
        self.interp = nn.functional.interpolate
        self.size = size
        self.mode = mode
        self.scale_factor = scale_factor
        if align_corners is None and mode in self._ALIGNABLE_MODES:
            align_corners = False
        self.align_corners = align_corners

    def forward(self, x):
        """Resize ``x`` using the stored size/scale_factor and mode."""
        return self.interp(
            x,
            size=self.size,
            scale_factor=self.scale_factor,
            mode=self.mode,
            align_corners=self.align_corners,
        )

    def extra_repr(self):
        return "size={}, scale_factor={}, mode={!r}, align_corners={}".format(
            self.size, self.scale_factor, self.mode, self.align_corners
        )
    

class MeeDecoder(torch.nn.Module):
    """Decoder whose resampling steps are ``Interpolate`` wrapper modules.

    Layer pipeline, in order:
        1x1 conv 40->30, ReLU, 1x1 conv 30->20, ReLU,
        bilinear resize to (205, 5), 3x3 conv 20->12, ReLU,
        bilinear resize to (1025, 15), 1x1 conv 12->1, ReLU.
    """

    def __init__(self):
        super().__init__()
        conv = torch.nn.Conv2d

        # Registration order is kept: de_layer1..6, de_layerReLU, decoder.
        self.de_layer1 = conv(40, 30, kernel_size=(1, 1), stride=1, bias=True)
        self.de_layer2 = conv(30, 20, kernel_size=(1, 1), stride=1, bias=True)
        self.de_layer3 = Interpolate(size=(205, 5), mode='bilinear')
        self.de_layer4 = conv(20, 12, kernel_size=(3, 3), stride=1, bias=True)
        self.de_layer5 = Interpolate(size=(1025, 15), mode='bilinear')
        self.de_layer6 = conv(12, 1, kernel_size=(1, 1), stride=1, bias=True)
        self.de_layerReLU = torch.nn.ReLU()

        # One parameter-free ReLU instance is shared by every activation slot.
        act = self.de_layerReLU
        stages = [
            self.de_layer1, act,
            self.de_layer2, act,
            self.de_layer3,
            self.de_layer4, act,
            self.de_layer5,
            self.de_layer6, act,
        ]
        self.decoder = torch.nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the decoder pipeline."""
        decoded = self.decoder(x)
        return decoded
24 Likes

Thank you @ptrblck. It works. :slight_smile:

May I ask why align_corners is passed as an additional argument in the Interpolate class? It was not passed for the deprecated upsampling.

1 Like

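You don’t specify the input: interpolate() is a function that must be called with a tensor, not a layer you can construct in advance.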

1 Like

The two functions were designed for different cases. The interpolate function has a wider range of applications.

1 Like

Could you please tell me how I can calculate the size argument of interpolate?

You are not calculating but defining the size argument for interpolate as it depends on your use case.
E.g. you could use this operator to down/upscale a tensor using a defined shape.