class ResnetEncoder(nn.Module):
    """PyTorch module wrapping a ResNet backbone as a multi-scale feature encoder.

    ``forward`` returns five tensors: the stem activation (conv1 -> bn1 -> relu)
    followed by the outputs of ``layer1`` .. ``layer4`` of the wrapped network.

    Args:
        num_layers: ResNet depth; must be one of 18, 34, 50, 101 or 152.
        pretrained: forwarded positionally to the torchvision ResNet factory
            (or to ``resnet_multiimage_input`` when ``num_input_images > 1``).
        num_input_images: when greater than 1 the backbone is built by
            ``resnet_multiimage_input`` instead of the torchvision factory.

    Raises:
        ValueError: if ``num_layers`` is not a supported ResNet depth.
    """

    def __init__(self, num_layers, pretrained, num_input_images=1):
        super(ResnetEncoder, self).__init__()

        # Presumably the channel count of each of the five feature maps
        # returned by forward() — confirm against the decoder that reads it.
        self.num_ch_enc = np.array([64, 64, 128, 256, 512])

        resnets = {18: models.resnet18,
                   34: models.resnet34,
                   50: models.resnet50,
                   101: models.resnet101,
                   152: models.resnet152}

        if num_layers not in resnets:
            raise ValueError("{} is not a valid number of resnet layers".format(num_layers))

        if num_input_images > 1:
            self.encoder = resnet_multiimage_input(num_layers, pretrained, num_input_images)
        else:
            self.encoder = resnets[num_layers](pretrained)

        # Deeper ResNets (50+) use bottleneck blocks whose outputs carry 4x the
        # channels; the stem output (index 0) is unaffected.
        if num_layers > 34:
            self.num_ch_enc[1:] *= 4

    def forward(self, input_image):
        """Return the list of five multi-scale feature maps for ``input_image``.

        Features are accumulated in a local list instead of ``self.features``:
        assigning a new module attribute inside ``forward`` cannot be compiled
        by ``torch.jit.script``, and it made every call mutate module state.
        The returned value is unchanged.
        """
        # Fixed input normalisation constants (NOTE(review): likely an
        # approximate dataset mean/std — confirm).
        x = (input_image - 0.45) / 0.225
        x = self.encoder.conv1(x)
        x = self.encoder.bn1(x)

        # An empty list literal is typed as List[Tensor] by TorchScript.
        features = []
        features.append(self.encoder.relu(x))
        features.append(self.encoder.layer1(self.encoder.maxpool(features[-1])))
        features.append(self.encoder.layer2(features[-1]))
        features.append(self.encoder.layer3(features[-1]))
        features.append(self.encoder.layer4(features[-1]))
        return features
# Trace the ResnetEncoder wrapper itself, not its inner torchvision ResNet:
# the inner module's forward() continues through avgpool + fc, so tracing it
# yields different outputs than ResnetEncoder.forward() (the list of features).
encoder = ResnetEncoder(18, True)
# Tracing records the BatchNorm mode in effect, so switch to inference mode first.
encoder.eval()
# NOTE(review): confirm the (height, width) order of this example matches the
# inputs used at inference time.
example = torch.rand(1, 3, 640, 192)
traced_script_module_encoder = torch.jit.trace(encoder, example)
# Save via the name actually bound above (the original referenced an undefined name).
traced_script_module_encoder.save('encoder_new.pt')
loaded_encoder = torch.jit.load('encoder_new.pt')
I converted the model with `torch.jit.trace` and loaded it back, but the loaded model returns features with different shapes. As the community also suggested, tracing will not work here ("Tracing doesn’t understand dynamic control flow, so sometimes it will “constant-ify” shapes in your model. Try turning your model into a ScriptModule and using TorchScript.")
However, when I try to convert it with `torch.jit.script`, I get the following error:
TypeError: module, class, method, function, traceback, frame, or code object was expected, got ResnetEncoder
when using the example below:
# Build an 18-layer encoder with pretrained backbone weights, compile it with
# TorchScript, and serialize the compiled module to disk.
encoder = ResnetEncoder(num_layers=18, pretrained=True)
traced_script_module_encoder = torch.jit.script(encoder)
traced_script_module_encoder.save('new-encoder.pt')