You could wrap your interpolate function in an nn.Module as a workaround:
class Interpolate(nn.Module):
    """Module wrapper around ``nn.functional.interpolate``.

    Wrapping the functional call in a module lets a fixed-size resampling
    step be placed wherever an ``nn.Module`` is required, for example inside
    ``nn.Sequential``.

    Args:
        size: Target output size, forwarded unchanged as ``size=``.
        mode: Interpolation algorithm name, forwarded unchanged as ``mode=``.
        align_corners: Forwarded as ``align_corners=`` when ``mode`` is one
            of the interpolating modes (linear, bilinear, bicubic,
            trilinear). Defaults to ``False``, the value that used to be
            hard-coded. It is ignored for every other mode, because
            ``interpolate`` rejects an explicit value there.
    """

    # Modes for which ``interpolate`` accepts an explicit ``align_corners``.
    _ALIGNABLE_MODES = frozenset({"linear", "bilinear", "bicubic", "trilinear"})

    def __init__(self, size, mode, align_corners=False):
        super().__init__()
        self.interp = nn.functional.interpolate
        self.size = size
        self.mode = mode
        self.align_corners = align_corners

    def forward(self, x):
        """Return ``x`` resampled to ``self.size`` using ``self.mode``."""
        # ``interpolate`` raises ValueError when align_corners is set for a
        # non-interpolating mode (e.g. 'nearest'), so pass None in that case.
        if self.mode in self._ALIGNABLE_MODES:
            align_corners = self.align_corners
        else:
            align_corners = None
        return self.interp(
            x, size=self.size, mode=self.mode, align_corners=align_corners
        )
class MeeDecoder(torch.nn.Module):
    """Convolutional decoder with two fixed-size bilinear resizing steps.

    Channel flow through the convolutions is 40 -> 30 -> 20 -> 12 -> 1.
    The pipeline run by ``forward`` is:

        1x1 conv -> ReLU -> 1x1 conv -> ReLU -> resize to (205, 5)
        -> 3x3 conv -> ReLU -> resize to (1025, 15) -> 1x1 conv -> ReLU

    Every stage is kept both as an individual ``de_layer*`` attribute and as
    a member of ``self.decoder``; one ReLU instance is shared by all
    activation slots. The resizing steps use the ``Interpolate`` module
    defined elsewhere in this file.
    """

    def __init__(self):
        super().__init__()
        conv2d = torch.nn.Conv2d

        # Layers are created in numeric order so that parameter
        # initialisation order and registration order stay stable.
        self.de_layer1 = conv2d(
            in_channels=40, out_channels=30, kernel_size=(1, 1), stride=1, bias=True
        )
        self.de_layer2 = conv2d(
            in_channels=30, out_channels=20, kernel_size=(1, 1), stride=1, bias=True
        )
        self.de_layer3 = Interpolate(size=(205, 5), mode='bilinear')
        self.de_layer4 = conv2d(
            in_channels=20, out_channels=12, kernel_size=(3, 3), stride=1, bias=True
        )
        self.de_layer5 = Interpolate(size=(1025, 15), mode='bilinear')
        self.de_layer6 = conv2d(
            in_channels=12, out_channels=1, kernel_size=(1, 1), stride=1, bias=True
        )
        self.de_layerReLU = torch.nn.ReLU()

        relu = self.de_layerReLU
        stages = [
            self.de_layer1,
            relu,
            self.de_layer2,
            relu,
            self.de_layer3,
            self.de_layer4,
            relu,
            self.de_layer5,
            self.de_layer6,
            relu,
        ]
        self.decoder = torch.nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the full decoder pipeline and return the result."""
        decoded = self.decoder(x)
        return decoded