How to split a tensor dimension in 2 with JIT?

Basically I am trying to JIT this op:

x.view(*x.shape[:-2], x.shape[-2] // 2, 2)

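The idea is to split the last dimension of the original tensor into pairs (a size of N becomes N // 2 x 2; in the code below x is unsqueezed first, which is why the op indexes dimension -2), and then apply a linear layer with shape 2 -> 16 to each pair.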
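For reference, this is how the op behaves in eager mode. This is a small sketch: the input shape (3, 8) and the freshly initialised nn.Linear(2, 16) standing in for the trained layer are illustrative assumptions.

```python
import torch
from torch import nn

x = torch.randn(3, 8)                            # 3 rows, last dimension N = 8
x = x.unsqueeze(-1)                              # (3, 8, 1), as in forward() below
x = x.view(*x.shape[:-2], x.shape[-2] // 2, 2)   # (3, 4, 2): N values -> N // 2 pairs
linear = nn.Linear(2, 16)                        # stand-in for the 2 -> 16 projection
out = linear(x)                                  # applied to every pair independently
print(x.shape, out.shape)  # torch.Size([3, 4, 2]) torch.Size([3, 4, 16])
```

In eager mode the starred unpacking is ordinary Python, so this runs; the trouble only starts when the same line is compiled by TorchScript.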

The problem is that TorchScript (torch.jit.script) rejects this line, because *x.shape[:-2] unpacks a list whose length it cannot determine at compile time:

cannot statically infer the expected size of a list in this context

Is this impossible in general, or is there another way to compute the same thing?
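One approach that should work, sketched here under the assumption that only the shape computation needs to change: build the target shape as an explicit List[int], which TorchScript can type, and pass it to view as a single argument instead of unpacking it with *. The helper name `split_last_dim_in_pairs` is hypothetical, and the sketch splits the last dimension directly rather than unsqueezing first.

```python
from typing import List

import torch
from torch import Tensor


def split_last_dim_in_pairs(x: Tensor) -> Tensor:
    # Hypothetical helper: (..., N) -> (..., N // 2, 2).
    # The shape is built as a List[int]; no *-unpacking, so TorchScript can type it.
    # list(...) also makes the concatenation valid in eager mode, where x.shape
    # is a torch.Size (a tuple subclass) rather than a list.
    new_shape: List[int] = list(x.shape[:-1]) + [x.shape[-1] // 2, 2]
    return x.view(new_shape)


scripted = torch.jit.script(split_last_dim_in_pairs)

x = torch.arange(24.0).reshape(3, 8)
y = scripted(x)
print(y.shape)  # torch.Size([3, 4, 2])
```

view accepts the whole size list as one argument both in eager mode and in TorchScript, so the same function works scripted or not.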

Ultimately, I am trying to inline a trained activation network into a scripted module, like this:

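import torch
from torch import nn, Tensor

# activation (the trained model) and geglu are defined elsewhere in my code
class Activation2(nn.Module):
    def __init__(self, dtype=None):
        super().__init__()
        # reuse the trained layers of the original activation network
        self.l1 = activation.activation[2]
        self.l2 = activation.activation[4]
        self.l3 = activation.activation[6]
        self.dtype = dtype

    def forward(self, x: Tensor):
        x = x.unsqueeze(-1)                              # (..., N) -> (..., N, 1)
        # (..., N // 2, 2); torch.jit.script fails on this line
        x = x.view(*x.size()[:-2], x.size(-2) // 2, 2)
        x = x.to(torch.float32)
        x = self.l1(x)
        x = geglu(x)
        x = self.l2(x)
        x = geglu(x)
        x = self.l3(x)
        x = x.to(self.dtype)
        return x.squeeze(-1)

module = Activation2(dtype)
scripted_module = torch.jit.script(module.eval())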
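Putting it together, here is a sketch of a scriptable version of the module. Everything not in the original is an assumption: the layer sizes (2 -> 16, 8 -> 16, 8 -> 1), the `geglu` definition (one half of the features gated by GELU of the other half), and casting back to the input's dtype instead of a stored `self.dtype`. In practice the nn.Linear stand-ins would be replaced by the trained `activation.activation[...]` layers.

```python
from typing import List

import torch
import torch.nn.functional as F
from torch import Tensor, nn


def geglu(x: Tensor) -> Tensor:
    # Hypothetical GEGLU: split the features in half, gate one half with GELU.
    a, gate = x.chunk(2, dim=-1)
    return a * F.gelu(gate)


class Activation2(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        # Hypothetical stand-ins for activation.activation[2], [4] and [6].
        self.l1 = nn.Linear(2, 16)
        self.l2 = nn.Linear(8, 16)
        self.l3 = nn.Linear(8, 1)

    def forward(self, x: Tensor) -> Tensor:
        orig_dtype = x.dtype
        # (..., N) -> (..., N // 2, 2) using an explicit List[int] shape.
        new_shape: List[int] = list(x.shape[:-1]) + [x.shape[-1] // 2, 2]
        x = x.view(new_shape).to(torch.float32)
        x = geglu(self.l1(x))  # (..., N // 2, 16) -> (..., N // 2, 8)
        x = geglu(self.l2(x))  # (..., N // 2, 16) -> (..., N // 2, 8)
        x = self.l3(x)         # (..., N // 2, 1)
        return x.to(orig_dtype).squeeze(-1)  # (..., N // 2)


scripted_module = torch.jit.script(Activation2().eval())
y = scripted_module(torch.randn(4, 10, dtype=torch.float16))
print(y.shape, y.dtype)  # torch.Size([4, 5]) torch.float16
```

The new_shape line is the change aimed at the scripting error; dropping the unsqueeze(-1) and casting back to the input dtype are simplifications made for this sketch.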