Hi,
I am new to torch.jit. After reading the docs about this feature, I still have some trouble understanding the philosophy behind torch.jit.
I have a model implemented with nn.Module, like this:
class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(...)

    def forward(self, x):
        return self.conv(x)
Then I wrap it with torch.jit, like this:
class ModelJIT(ScriptModule):
    def __init__(self):
        super().__init__()
        self.model = trace(Model(), torch.randn(16, 3, 256, 128))

    @script_method
    def forward(self, x):
        return self.model(x)
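As far as I understand, tracing can also be used on its own, without a ScriptModule wrapper. A minimal, self-contained sketch of what I think tracing does (the tiny model and shapes here are just placeholders, not my real model):

```python
import torch
import torch.nn as nn

class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        # placeholder layer; my real model is larger
        self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)

    def forward(self, x):
        return self.conv(x)

# trace() runs the module once on the example input and
# records the executed ops into a ScriptModule
example = torch.randn(1, 3, 16, 16)
traced = torch.jit.trace(TinyModel(), example)

# the traced module is callable just like the original
out = traced(torch.randn(1, 3, 16, 16))
print(out.shape)  # torch.Size([1, 8, 16, 16])
```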
Then, rather than creating the model with the Model class, I create it with ModelJIT and train:
# net = Model()
net = ModelJIT()
net.cuda()
net.train()
loss_obj = nn.CrossEntropyLoss()
optim = SGD(...)

for im, lb in dataloader:
    optim.zero_grad()
    logits = net(im.cuda())
    loss = loss_obj(logits, lb.cuda())  # labels must be on the same device as logits
    loss.backward()
    optim.step()
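For reference, the rough way I compared eager vs. traced speed was something like the sketch below (the model, shapes, and iteration count are arbitrary placeholders, not my real setup):

```python
import time
import torch
import torch.nn as nn

def bench(model, inp, iters=10):
    # warm up once, then time repeated forward passes
    model(inp)
    start = time.perf_counter()
    for _ in range(iters):
        model(inp)
    return time.perf_counter() - start

model = nn.Conv2d(3, 8, 3, padding=1)
inp = torch.randn(1, 3, 32, 32)
traced = torch.jit.trace(model, inp)

t_eager = bench(model, inp)
t_jit = bench(traced, inp)
print(f"eager: {t_eager:.4f}s, traced: {t_jit:.4f}s")
```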
With this code my program runs, but I only notice a much slower startup, without any speed improvement during training. So is torch.jit designed to accelerate the program, or is there a mistake in my implementation? Please show me the correct way to work with torch.jit.