JIT script error when Sequential container takes a Tuple input

This is a simple net to reproduce my error. I’m passing a Tuple to the forward method and have specified the typing. I think the error is caused by Jit inferring the input type to Sequential’s forward method to be a Tensor, and not a Tuple. How can I fix this error?

class MyBatchNorm(nn.Module):
    def __init__(self, output_size, d_ids):
        super().__init__()
        self.d_ids = d_ids
        self.net = nn.ModuleDict({f"{d}": nn.BatchNorm1d(output_size) for d in d_ids})
    
    def forward(self, input_tuple: Tuple[torch.Tensor, int]) -> Tuple[torch.Tensor, int]:
        input_tensor, d = input_tuple
        output_tensor = torch.tensor([])
        for d_name, d_norm in self.net.items():
            if f"{d}" == d_name:
                output_tensor = d_norm(input_tensor)
        if len(output_tensor) == 0:
            raise ValueError(f"invalid d {d}, must be {self.d_ids}")
        return output_tensor, d

class MyNet(nn.Module):
    def __init__(self, output_size, d_ids):
        super().__init__()
        dense_layers = [
            MyBatchNorm(output_size, d_ids),
            MyBatchNorm(output_size, d_ids)
        ]
        self.net = torch.nn.Sequential(*dense_layers)
        
    def forward(self, input_tensor: torch.Tensor, d_tensor: torch.Tensor) -> torch.Tensor:
        d = d_tensor.squeeze()[0].item()
        output_tensor, _ = self.net((input_tensor, d))
        return torch.squeeze(output_tensor)

Error:

RuntimeError: 

forward(__torch__.___torch_mangle_16.MyBatchNorm self, (Tensor, int) input_tuple) -> ((Tensor, int)):
Expected a value of type 'Tuple[Tensor, int]' for argument 'input_tuple' but instead found type 'Tensor (inferred)'.
Inferred the value for argument 'input_tuple' to be of type 'Tensor' because it was not annotated with an explicit type.
:
  File "/home/ec2-user/anaconda3/envs/pytorch_latest_p36/lib/python3.6/site-packages/torch/nn/modules/container.py", line 117
    def forward(self, input):
        for module in self:
            input = module(input)
                    ~~~~~~ <--- HERE
        return input