Indexing during runtime into a JITed ModuleList


I created a module that holds a list of different modules and randomly applies one of them on each forward pass. This is my current (ugly!) implementation:

class OneOf(nn.Module):
    """Apply exactly one of several sub-modules, chosen per forward call.

    Each candidate layer has a (non-trainable) weight; the weights are turned
    into a categorical distribution with ``log_softmax``.  On every call one
    layer is picked, either sampled from that distribution or taken as its
    mode, and applied to the input.

    Args:
        *args: Candidate modules.  If none are given, a single
            ``nn.Identity`` is used so the module is still callable.
    """

    def __init__(self, *args) -> None:
        # nn.Module must be initialised before sub-modules/parameters are set.
        super().__init__()

        self.layers = nn.ModuleList(list(args) if args else [nn.Identity()])

        # requires_grad=False: the selection weights are not updated by
        # autograd; they are registered as a Parameter so they still live in
        # the module's state_dict.
        self._weights = nn.parameter.Parameter(
            torch.rand(len(self.layers)), requires_grad=False
        )

    def forward(
        self, x: torch.Tensor, sample: bool = True
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply one chosen layer to ``x``.

        Args:
            x: Input passed to the selected layer.
            sample: If True, draw the layer index from the categorical
                distribution over the weights; otherwise take the argmax.

        Returns:
            ``(y, log_prob)`` where ``y`` is the selected layer's output and
            ``log_prob`` is the 0-d log-probability of the chosen index.
        """
        log_prob = nn.functional.log_softmax(self._weights, dim=0)

        # Convert to a plain int in both branches so the comparison below is
        # int-vs-int and ``log_prob[idx]`` is a 0-d tensor either way
        # (multinomial alone would return a shape-(1,) tensor).
        if sample:
            idx = int(torch.multinomial(log_prob.exp(), 1).item())
        else:
            idx = int(torch.argmax(log_prob).item())

        # TorchScript rejects indexing a ModuleList with a non-constant
        # index, so walk the (unrolled) list and apply the matching layer.
        y = torch.zeros_like(x)
        for i, layer in enumerate(self.layers):
            if i == idx:
                y = layer(x)

        return y, log_prob[idx]

if __name__ == '__main__':
    # Script an empty OneOf; with no layers given it falls back to a single
    # nn.Identity, so the output must equal the input.
    scripted = torch.jit.script(OneOf())

    inp = torch.rand(2, 3, 5)
    out, lp = scripted(inp)

    # The only candidate has probability 1, hence log-probability 0.
    assert torch.allclose(inp, out)
    assert np.isclose(lp.item(), 0.0)

This hideous loop over all layers with an index comparison is my workaround to avoid indexing the ModuleList with a non-literal (runtime) index. Is there a better way to do this without breaking compilation to TorchScript?

Thanks for your help!