I’m trying to convert a model with torch.jit.trace, and I know where the source of my problem is. If you look at the code below:
    class GLU(nn.Module):
        def __init__(self):
            super(GLU, self).__init__()
            self.sigmoid = nn.Sigmoid()

        def forward(self, x):
            nc = x.size(1)
            nc = int(nc / 2)
            return x[:, :nc] * self.sigmoid(x[:, nc:])
The tracer tells me I’m not allowed to convert a torch value to a Python integer on this line:
    nc = int(nc/2)
    TracerWarning: Converting a tensor to a Python integer might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
Yet Python slice indices cannot be torch values, only integers. Is there a way to modify the forward function to avoid the TracerWarning? One solution I can think of is to do the split through high-level torch functions instead of Python indexing, but I don’t know the best place to start looking.
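For what it’s worth, here is a sketch of the kind of rewrite I mean, assuming the goal is to keep the split symbolic: `torch.Tensor.chunk` splits along a dimension without ever materializing a Python int, so the trace should stay shape-generic (and `torch.nn.functional.glu` implements this exact gating in one call):

```python
import torch
import torch.nn as nn


class GLU(nn.Module):
    """GLU variant that traces without baking the channel count in as a constant."""

    def __init__(self):
        super().__init__()
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # chunk splits dim 1 in half inside the graph; no int() conversion needed
        a, b = x.chunk(2, dim=1)
        return a * self.sigmoid(b)


# Trace with one channel count, then run with another to check it generalizes
traced = torch.jit.trace(GLU(), torch.randn(2, 8, 4))
out = traced(torch.randn(2, 6, 4))  # 6 channels -> two halves of 3
```

The same result should also be obtainable directly with `torch.nn.functional.glu(x, dim=1)`.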