I’m trying to convert a model with torch.jit.trace, and I know where my problem comes from. Consider the module below:
class GLU(nn.Module):
    def __init__(self):
        super(GLU, self).__init__()
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        nc = x.size(1)
        nc = int(nc/2)
        return x[:, :nc] * self.sigmoid(x[:, nc:])
The tracer warns me about converting a tensor value to a Python integer on this line:
nc = int(nc/2)
TracerWarning: Converting a tensor to a Python integer might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
nc = int(nc/2)
Yet Python slice indices have to be integers; they cannot be tensor values. Is there a way to modify the forward function to avoid the TracerWarning? One solution I can think of is to do the split with a higher-level torch function instead of manual indexing, but I don’t know the best place to start looking.
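To make the idea concrete, here is a minimal sketch of what such a higher-level approach might look like, assuming `torch.chunk` is an acceptable replacement for the manual slicing and that the channel dimension is always even. Since the split happens inside a tensor op, no Python integer is computed in `forward`. The example input shapes are made up for illustration.

```python
import torch
import torch.nn as nn


class GLU(nn.Module):
    """Gated linear unit over dim 1; assumes an even number of channels."""

    def forward(self, x):
        # torch.chunk splits dim 1 into two halves as a tensor operation,
        # so the half size is never converted to a Python int.
        a, b = torch.chunk(x, 2, dim=1)
        return a * torch.sigmoid(b)


# Trace with one (hypothetical) input shape ...
traced = torch.jit.trace(GLU(), torch.randn(2, 4, 3))
# ... and call the traced module with a different channel count.
y = traced(torch.randn(2, 6, 3))
```

If the gating is all that is needed, `torch.nn.functional.glu(x, dim=1)` computes the same first-half times sigmoid-of-second-half product in a single call.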