I noticed that `torch.jit.trace()` calls `forward()` three times. Run the code below:
```python
import torch

class MyCell(torch.nn.Module):
    def __init__(self):
        super(MyCell, self).__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x, h):
        # Printed every time the Python forward() actually runs
        print("execute forward")
        new_h = torch.tanh(self.linear(x) + h)
        return new_h, new_h

my_cell = MyCell()
x = torch.rand(3, 4)
h = torch.rand(3, 4)
# Trace the module with example inputs (x, h)
traced_cell = torch.jit.trace(my_cell, (x, h))
```
The output is:

```
execute forward
execute forward
execute forward
```
I also tested other models, and they all call `forward()` three times. Why does `torch.jit.trace()` need to call `forward()` three times, and what is the purpose of each call? Thanks!
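A variant of the repro that counts calls instead of printing makes the number easy to compare across settings. This is a minimal sketch assuming the extra calls come from the trace check that `torch.jit.trace` performs by default (its `check_trace=True` parameter); `call_count` and `count_forward_calls` are hypothetical helpers added only for illustration.

```python
import torch

# Hypothetical global counter, incremented on every eager run of forward()
call_count = 0

class MyCell(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x, h):
        global call_count
        call_count += 1  # Python side effect: counts each real call of forward()
        new_h = torch.tanh(self.linear(x) + h)
        return new_h, new_h

def count_forward_calls(check_trace):
    """Hypothetical helper: trace a fresh MyCell and return how often forward() ran."""
    global call_count
    call_count = 0
    x, h = torch.rand(3, 4), torch.rand(3, 4)
    torch.jit.trace(MyCell(), (x, h), check_trace=check_trace)
    return call_count

print("check_trace=False:", count_forward_calls(check_trace=False))
print("check_trace=True: ", count_forward_calls(check_trace=True))
```

With `check_trace=False` only the tracing run itself should remain, so the count is expected to be 1; with the default setting, the count should match the three prints shown above (the exact number may vary between PyTorch versions).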