Hello,
Can I use torch.jit.trace after I have already created the optimizer? As soon as I add torch.jit.trace, my loss explodes.
Example pseudocode:
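import torch
from torch.optim import SGD

model = Model()  # Model is an nn.Module
loss_fn = LossFn(model)  # LossFn is an nn.Module that wraps model but has no parameters of its own
optimizer = SGD(model.parameters())  # optimizer is created BEFORE tracing

# Trace the loss module once, using the first batch as the example input
batch = next(iter(dataloader))  # a DataLoader is iterable, not an iterator, so wrap it in iter()
loss_fn = torch.jit.trace(loss_fn, (batch,))

# Training
for batch in dataloader:
    loss = loss_fn(batch)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()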
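To narrow this down, here is a minimal, self-contained sketch of the same pattern (optimizer first, then trace) that checks whether the traced loss module actually trains the parameters the optimizer holds. Model and LossFn below are hypothetical toy stand-ins, and the batch is split into two tensors (x, y) instead of a single batch object to keep tracing simple; the learning rate and shapes are arbitrary.

```python
import torch
import torch.nn as nn


class Model(nn.Module):
    """Hypothetical stand-in for the real model."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 1)

    def forward(self, x):
        return self.linear(x)


class LossFn(nn.Module):
    """Hypothetical loss wrapper: holds the model but no parameters of its own."""

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x, y):
        return ((self.model(x) - y) ** 2).mean()


def check_traced_training_step(seed=0):
    torch.manual_seed(seed)
    model = Model()
    loss_fn = LossFn(model)
    # Optimizer is created BEFORE tracing, as in the question.
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    x, y = torch.randn(8, 4), torch.randn(8, 1)
    traced = torch.jit.trace(loss_fn, (x, y))

    # 1. Does the traced module use the same parameter storage as `model`?
    shared = {p.data_ptr() for p in traced.parameters()} == {
        p.data_ptr() for p in model.parameters()
    }

    # 2. Does the traced loss match the eager loss on the same batch?
    loss = traced(x, y)
    same_value = torch.allclose(loss, loss_fn(x, y))

    # 3. Do gradients from the traced graph reach the optimizer's parameters?
    loss.backward()
    got_grads = all(p.grad is not None for p in model.parameters())

    # 4. Does one optimizer step actually change those parameters?
    before = [p.detach().clone() for p in model.parameters()]
    optimizer.step()
    optimizer.zero_grad()
    stepped = any(not torch.equal(b, p) for b, p in zip(before, model.parameters()))

    return shared, same_value, got_grads, stepped


print(check_traced_training_step())
```

If all four flags come back True on a toy setup like this, the order of creating the optimizer and tracing is probably not what makes the loss explode. One thing worth ruling out next: tracing records only the operations executed for the example batch, so any data-dependent control flow inside LossFn or Model is frozen to the path taken by that first batch.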