Basically I am trying to JIT this op:

```
x.view(*x.shape[:-2], x.shape[-2] // 2, 2)
```

The idea is to split the last dimension of the tensor into pairs of 2, and then apply a linear layer with shape `2 -> 16`.
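
For concreteness, here is what I mean in eager mode, with made-up shapes (a batch of 4 vectors of length 8):

```
import torch
from torch import nn

x = torch.randn(4, 8)                             # made-up example input
x = x.unsqueeze(-1)                               # (4, 8, 1)
x = x.view(*x.shape[:-2], x.shape[-2] // 2, 2)    # (4, 4, 2)
out = nn.Linear(2, 16)(x)                         # (4, 4, 16)
```

This runs fine in eager mode; it is only scripting that fails.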

The problem is that the JIT complains about this, because it cannot expand `*x.shape[:-2]`:

```
cannot statically infer the expected size of a list in this context
```
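
As far as I can tell, the star-unpacking alone is enough to trigger it; this standalone function (the name is just for the example) fails to compile the same way:

```
import torch
from torch import Tensor

@torch.jit.script  # raises the error above at scripting time
def split_last(x: Tensor) -> Tensor:
    # TorchScript cannot tell how many arguments *x.shape[:-2] expands to
    return x.view(*x.shape[:-2], x.shape[-2] // 2, 2)
```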

Is this generally impossible? Perhaps there’s another way to compute the same thing?
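
For example, would building the shape as a plain list and passing it to `view` without any unpacking be valid? A sketch of what I mean (function name hypothetical):

```
import torch
from torch import Tensor

@torch.jit.script
def split_last(x: Tensor) -> Tensor:
    # In TorchScript, size() is typed as List[int], so slicing and list
    # concatenation are fine, and view() accepts the list directly.
    # (In eager mode size() returns a torch.Size tuple, so the slice may
    # need wrapping in list() there.)
    new_shape = x.size()[:-2] + [x.size(-2) // 2, 2]
    return x.view(new_shape)
```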

In the end, I am trying to inline a trained activation like this:

```
import torch
from torch import nn, Tensor

class Activation2(nn.Module):
    def __init__(self, dtype=None):
        super().__init__()
        # Trained sub-layers pulled out of the original activation module
        self.l1 = activation.activation[2]
        self.l2 = activation.activation[4]
        self.l3 = activation.activation[6]
        self.dtype = dtype

    def forward(self, x: Tensor):
        x = x.unsqueeze(-1)
        # This is the line TorchScript rejects
        x = x.view(*x.size()[:-2], x.size(-2) // 2, 2)
        x = x.to(torch.float32)
        x = self.l1(x)
        x = geglu(x)
        x = self.l2(x)
        x = geglu(x)
        x = self.l3(x)
        x = x.to(self.dtype)
        return x.squeeze(-1)

module = Activation2(dtype)
scripted_module = torch.jit.script(module.eval())
```
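
For anyone who wants a self-contained repro: `geglu` is not shown above, but a standard gated-GELU stand-in like this should do (assuming the usual definition):

```
import torch.nn.functional as F
from torch import Tensor

def geglu(x: Tensor) -> Tensor:
    # Standard gated GELU: split the last dimension in half and gate.
    a, b = x.chunk(2, dim=-1)
    return a * F.gelu(b)
```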