Unable to load a torch.nn.Module subclass on SageMaker

I created a model by subclassing torch.nn.Module:

import torch


class Model(torch.nn.Module):
    def __init__(self, model, logit_layer):
        super(Model, self).__init__()
        self.model = model  # encoder whose output exposes last_hidden_state
        self.logit_layer = logit_layer  # maps the pooled hidden state to logits

    def forward(self, input_ids, mask, tokens):
        # (batch, seq_len) -> (batch, seq_len, 1) so the mask broadcasts over the hidden size
        pre_logits_mask = torch.reshape(mask, (mask.shape[0], mask.shape[1], 1))
        outputs = self.model(input_ids, attention_mask=mask, token_type_ids=tokens)
        last_hidden_state = outputs.last_hidden_state
        # Zero out padding positions, then sum the remaining token states over the sequence
        last_hidden_state_zero_layer = torch.mul(last_hidden_state, pre_logits_mask)
        summed_final_hidden_state = torch.sum(last_hidden_state_zero_layer, 1)
        logits = self.logit_layer(summed_final_hidden_state)
        probs = torch.sigmoid(logits)
        return probs

It works fine when I save and load it locally, but SageMaker is unable to load this object when I use SageMaker's PyTorch wrapper. When I checked the type of the model above, it returned:
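A possible explanation for the local-versus-SageMaker difference, assuming the model was saved whole with torch.save(model) rather than as a state_dict (the post does not say which): torch.save uses Python's pickle by default, and pickle stores a class by reference, as its module and qualified name, not as code. The sketch below illustrates this with the standard library only. `Model` here is a hypothetical plain-Python stand-in for the class above, and `load_without_class` is a hypothetical helper that temporarily removes the class from its module to mimic a serving process where the class was never defined.

```python
import pickle
import sys


class Model:
    """Hypothetical stand-in for the question's Model (plain Python, so torch is not needed)."""

    def __init__(self, weights):
        self.weights = weights


def load_without_class(payload, cls):
    """Try to unpickle payload while cls is missing from its module, as in a fresh process."""
    module = sys.modules[cls.__module__]
    saved = getattr(module, cls.__qualname__)
    delattr(module, cls.__qualname__)  # simulate an environment that never defined the class
    try:
        pickle.loads(payload)
        return "loaded"
    except (AttributeError, pickle.UnpicklingError) as err:
        return f"failed: {err}"
    finally:
        setattr(module, cls.__qualname__, saved)  # restore the class afterwards


# Pickling the whole object records the class by module and name
# (for example "__main__" and "Model"), not the class's source code.
payload = pickle.dumps(Model([0.1, 0.2]))
print(Model.__module__.encode() in payload and b"Model" in payload)  # True

# In the process that defined Model, loading works.
print(pickle.loads(payload).weights)  # [0.1, 0.2]

# Where Model cannot be found under the recorded name, the same bytes fail to load.
print(load_without_class(payload, Model))
```

If this is the cause, whatever process loads the file needs the same class importable under the same module path. A class defined in a notebook lives in `__main__`, which the serving process's own `__main__` typically does not contain.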

Out[116]: __main__.Model

So the model I created doesn't appear to be a PyTorch object, which may be why SageMaker fails to load it. Is there a workaround for this?
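On the type check itself: type() reports the most-derived class, prefixed with the module it was defined in, so a subclass defined in a notebook shows up as `__main__.Model` even though it inherits from torch.nn.Module; `isinstance(model, torch.nn.Module)` returns True for the class above. The sketch below shows the same behavior with a hypothetical `Module` base class standing in for torch.nn.Module, so it runs without torch installed.

```python
class Module:
    """Hypothetical stand-in for torch.nn.Module so the sketch runs without torch."""


class Model(Module):
    """Subclass defined at the top level of a notebook or script, like the question's Model."""


model = Model()

# type() reports the most-derived class, qualified by the module it was defined in.
print(f"{type(model).__module__}.{type(model).__qualname__}")  # __main__.Model when run as a script

# isinstance() and issubclass() follow the inheritance chain, so the object is still a Module.
print(isinstance(model, Module))  # True
print(issubclass(Model, Module))  # True
```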