I noticed that `torch.jit.trace()` calls `forward()` three times. Run the code below:
```python
import torch

class MyCell(torch.nn.Module):
    def __init__(self):
        super(MyCell, self).__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x, h):
        print("execute forward")
        new_h = torch.tanh(self.linear(x) + h)
        return new_h, new_h

my_cell = MyCell()
x = torch.rand(3, 4)
h = torch.rand(3, 4)
traced_cell = torch.jit.trace(my_cell, (x, h))
```
The output is:

```
execute forward
execute forward
execute forward
```
I also tested other models, and they all call `forward()` three times. Why does `torch.jit.trace()` need to call `forward()` three times, and what is the purpose of each call? Thanks!
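One way to narrow this down is to count the calls instead of printing them, and to toggle the `check_trace` argument of `torch.jit.trace()` (it defaults to `True`). This is a sketch that assumes the extra runs come from the trace-checking step that `check_trace` enables. `CALLS` and `count_forward_calls` are hypothetical names made up for this experiment:

```python
import torch

# Hypothetical module-level counter, used only for this experiment.
CALLS = {"forward": 0}

class MyCell(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x, h):
        # Incremented only when Python actually executes forward();
        # running the already-traced TorchScript graph does not touch it.
        CALLS["forward"] += 1
        new_h = torch.tanh(self.linear(x) + h)
        return new_h, new_h

def count_forward_calls(check_trace):
    """Trace a fresh MyCell and return how many times forward() ran."""
    CALLS["forward"] = 0
    cell = MyCell()
    x, h = torch.rand(3, 4), torch.rand(3, 4)
    torch.jit.trace(cell, (x, h), check_trace=check_trace)
    return CALLS["forward"]

if __name__ == "__main__":
    print("check_trace=True :", count_forward_calls(True))
    print("check_trace=False:", count_forward_calls(False))
```

If the count drops to one with `check_trace=False`, then only one of the three calls is needed to record the graph. The others would belong to the verification pass that compares the traced module with the original Python module.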