Add argmax to a model that wraps a pretrained PyTorch model (deeplabv3)

Is it possible to add an argmax operation at the end of the pretrained deeplabv3_mobilenet_v3_large PyTorch model?
Basically, I want to know whether this can be done without any training: just reduce the model's original [1,21,513,513] output tensor to [1,513,513] to decrease gRPC latency in a model-server setup.
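For a sense of how much this saves, here is a rough sketch of the raw payload sizes. It assumes the server sends the tensor bytes as-is (real gRPC/protobuf encoding adds some overhead). The float32 logits take 21 × 513 × 513 × 4 bytes, while an int8 label map takes 513 × 513 × 1 byte.

```python
import torch

# Raw output of the segmentation head: one float32 score per class per pixel.
logits = torch.empty(1, 21, 513, 513, dtype=torch.float32)
# After argmax over the class dimension: one class index per pixel, stored as int8.
labels = torch.empty(1, 513, 513, dtype=torch.int8)


def nbytes(t: torch.Tensor) -> int:
    # Number of elements times bytes per element.
    return t.nelement() * t.element_size()


print(nbytes(logits))  # 22106196 (~22 MB)
print(nbytes(labels))  # 263169 (~0.26 MB)
print(nbytes(logits) // nbytes(labels))  # 84
```

The factor of 84 is simply 21 classes × 4 bytes per float32 value.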

I tried the code below. I managed to include the normalization in the model, but the argmax throws an error I don’t know how to handle.

Code:

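import torch


class ArgMax(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, input):
        # Collapse the class dimension: [N, C, H, W] logits -> [N, H, W] class indices
        x = torch.argmax(input, dim=1)
        # 21 class indices fit in int8, which shrinks the payload
        x = x.to(torch.int8)
        return x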


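class Normalize(torch.nn.Module):
    def __init__(self, mean, std):
        super().__init__()
        # Registered as buffers so they follow the model across .to(device) calls
        self.register_buffer("mean", torch.tensor(mean).reshape(1, 3, 1, 1))
        self.register_buffer("std", torch.tensor(std).reshape(1, 3, 1, 1))

    def forward(self, input):
        # Scale [0, 255] pixel values to [0, 1], then apply the ImageNet mean/std
        x = input / 255.0
        x = x - self.mean
        x = x / self.std
        return x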

# `model` is the pretrained deeplabv3_mobilenet_v3_large
new_model = torch.nn.Sequential(
    Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    model,
    ArgMax()
)

Error:

forward(__torch__.ArgMax self, Tensor x) -> (Tensor):
Expected a value of type 'Tensor (inferred)' for argument 'x' but instead found type 'Dict[str, Tensor]'.
Inferred 'x' to be of type 'Tensor' because it was not annotated with an explicit type.

I don’t know whether I am doing this right or whether it is even possible, so any input is very welcome!
Thanks!

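Based on the error message, it seems that your custom ArgMax module expects a tensor input, while the model returns a dict (the torchvision segmentation models put the main logits under the "out" key, plus "aux" when the auxiliary classifier is enabled).
You could change your custom module to accept the output dict and index the desired output.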
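An alternative to nn.Sequential, sketched here under a few assumptions, is a single wrapper whose forward normalizes the input, calls the model, picks the "out" entry and applies argmax. `SegmentationArgMax` and `_StandInSegModel` are hypothetical names. The stand-in only mimics the dict output of the torchvision segmentation models so that the sketch runs offline without downloading weights; in practice you would pass the loaded deeplabv3_mobilenet_v3_large instead.

```python
from typing import Dict, List

import torch


class _StandInSegModel(torch.nn.Module):
    """Hypothetical stand-in that mimics the dict output of torchvision segmentation models."""

    def __init__(self, num_classes: int = 21):
        super().__init__()
        self.head = torch.nn.Conv2d(3, num_classes, kernel_size=1)

    def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        return {"out": self.head(x)}


class SegmentationArgMax(torch.nn.Module):
    """Hypothetical wrapper: normalize -> model -> pick "out" -> argmax over classes."""

    def __init__(self, model: torch.nn.Module, mean: List[float], std: List[float]):
        super().__init__()
        self.model = model
        # Buffers follow .to(device) and are kept in the scripted module.
        self.register_buffer("mean", torch.tensor(mean).reshape(1, 3, 1, 1))
        self.register_buffer("std", torch.tensor(std).reshape(1, 3, 1, 1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = (x / 255.0 - self.mean) / self.std
        logits = self.model(x)["out"]  # [N, C, H, W]
        return torch.argmax(logits, dim=1).to(torch.uint8)  # [N, H, W]


wrapped = SegmentationArgMax(
    _StandInSegModel(), [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
).eval()
scripted = torch.jit.script(wrapped)
with torch.no_grad():
    labels = scripted(torch.randint(0, 256, (1, 3, 64, 64)).float())
print(labels.shape, labels.dtype)  # torch.Size([1, 64, 64]) torch.uint8
```

Because the dict never crosses a module boundary here, no Dict annotation is needed on any forward except the model's own. uint8 is used since class indices are non-negative; int8, as in the thread, works just as well for up to 127 classes.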


After looking into the dict you mentioned, I got it working. The custom ArgMax now looks like this:

from typing import Dict

import torch


class ArgMax(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x: Dict[str, torch.Tensor]) -> torch.Tensor:
        # The segmentation model returns a dict; the main logits are under "out"
        t = x["out"]
        t = torch.argmax(t, dim=1)
        t = t.to(torch.int8)
        return t

Thanks for your help!
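To sanity-check the fixed ArgMax, the whole pipeline can be scripted end to end. This sketch reuses the ArgMax and Normalize modules from the thread and replaces the real deeplabv3_mobilenet_v3_large with a hypothetical `_StandInSegModel` that returns the same kind of dict, so it runs offline. Swapping in the real model (in eval mode) would be the actual check.

```python
from typing import Dict

import torch


class ArgMax(torch.nn.Module):
    # Same as the fixed module above: take the "out" logits and reduce them to class indices.
    def forward(self, x: Dict[str, torch.Tensor]) -> torch.Tensor:
        t = torch.argmax(x["out"], dim=1)
        return t.to(torch.int8)


class Normalize(torch.nn.Module):
    def __init__(self, mean, std):
        super().__init__()
        self.register_buffer("mean", torch.tensor(mean).reshape(1, 3, 1, 1))
        self.register_buffer("std", torch.tensor(std).reshape(1, 3, 1, 1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return (x / 255.0 - self.mean) / self.std


class _StandInSegModel(torch.nn.Module):
    # Hypothetical stand-in for deeplabv3_mobilenet_v3_large: returns {"out": [N, 21, H, W]}.
    def __init__(self):
        super().__init__()
        self.head = torch.nn.Conv2d(3, 21, kernel_size=1)

    def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        return {"out": self.head(x)}


new_model = torch.nn.Sequential(
    Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    _StandInSegModel(),
    ArgMax(),
).eval()

# With the Dict annotation on ArgMax.forward, TorchScript accepts the dict from the model.
scripted = torch.jit.script(new_model)
with torch.no_grad():
    out = scripted(torch.randint(0, 256, (1, 3, 64, 64)).float())
print(out.shape, out.dtype)  # torch.Size([1, 64, 64]) torch.int8
```

A 64×64 input keeps the check fast; a 513×513 input would give the [1,513,513] output from the question.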