Why doesn't torch.jit speed up training?

Hi,

I am new to torch.jit. After reading the docs about this feature, I still have some trouble understanding the philosophy behind it.

I have a model implemented with nn.Module like this:

    import torch.nn as nn

    class Model(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = nn.Conv2d(...)  # layer arguments omitted

        def forward(self, x):
            return self.conv(x)

Then I wrap it with jit like this:

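    import torch
    from torch.jit import ScriptModule, script_method, trace

    class ModelJIT(ScriptModule):
        def __init__(self):
            super().__init__()
            # trace() runs Model once on this example input and records the ops
            self.model = trace(Model(), torch.randn(16, 3, 256, 128))

        @script_method
        def forward(self, x):
            return self.model(x)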

Then, instead of creating the model with class Model, I create it with ModelJIT and train it:

    from torch.optim import SGD

    # net = Model()
    net = ModelJIT()
    net.cuda()
    net.train()
    loss_obj = nn.CrossEntropyLoss()

    optim = SGD(...)

    for im, lb in dataloader:
        optim.zero_grad()
        logits = net(im.cuda())
        loss = loss_obj(logits, lb.cuda())  # labels must be on the same device as logits
        loss.backward()
        optim.step()

With this code my program runs, but the only difference I notice is a much slower startup, with no speed improvement during training. So is torch.jit designed to accelerate programs, or is there a mistake in my implementation? Could you show me the correct way to use torch.jit?

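torch.jit is not mainly meant to accelerate training; it is a way to create serializable and optimizable models (TorchScript) from PyTorch code. Any code written in TorchScript can be saved from your Python process and loaded in a process where there is no Python dependency. In your case the model is a single convolution, so almost all of the time is spent inside the convolution kernel itself: compiling the model adds startup cost but leaves little for the JIT to optimize.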
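A minimal sketch of that save/load round trip, assuming a recent PyTorch release. The layer sizes and the in-memory buffer are illustrative choices, not part of the original post; a file written with `torch.jit.save` can likewise be loaded from C++ with LibTorch's `torch::jit::load`.

```python
import io

import torch
import torch.nn as nn


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        # Hypothetical layer sizes, chosen only for this sketch.
        self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)

    def forward(self, x):
        return self.conv(x)


model = Model().eval()
example = torch.randn(2, 3, 32, 32)

# Compile to TorchScript by tracing with an example input.
traced = torch.jit.trace(model, example)

# Serialize to an in-memory buffer; torch.jit.save also accepts a file path.
buffer = io.BytesIO()
torch.jit.save(traced, buffer)
buffer.seek(0)

# Load it back. The loaded module does not need the Model class definition.
loaded = torch.jit.load(buffer)

with torch.no_grad():
    same = torch.allclose(model(example), loaded(example))
print(same)
```

The loaded object is a `ScriptModule` that carries its own graph and weights, which is what allows it to run in a process without the original Python source.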

It provides tools to incrementally transition a model from a pure Python program to a TorchScript program that can run independently of Python, for instance in a standalone C++ program. This makes it possible to train models in PyTorch using familiar tools and then export them to a production environment where running models as Python programs is undesirable for performance and multi-threading reasons.
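The incremental transition can be sketched as follows, assuming a recent PyTorch release: only one submodule is compiled with `torch.jit.script`, the rest of the model stays ordinary eager Python, and the mixed model still trains with a normal optimizer loop. `Block`, `Net` and the layer sizes are hypothetical names chosen for this sketch.

```python
import torch
import torch.nn as nn


class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)

    def forward(self, x):
        return torch.relu(self.conv(x))


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        # Only this submodule is compiled to TorchScript.
        self.block = torch.jit.script(Block())
        # The head stays ordinary eager-mode Python.
        self.head = nn.Linear(8, 10)

    def forward(self, x):
        x = self.block(x)
        x = x.mean(dim=(2, 3))  # global average pool -> shape (N, 8)
        return self.head(x)


net = Net()
opt = torch.optim.SGD(net.parameters(), lr=0.1)
loss_fn = nn.CrossEntropyLoss()

x = torch.randn(4, 3, 16, 16)
y = torch.randint(0, 10, (4,))

# One training step; the scripted submodule's parameters are updated too.
before = net.block.conv.weight.detach().clone()
opt.zero_grad()
loss = loss_fn(net(x), y)
loss.backward()
opt.step()
after = net.block.conv.weight.detach()
print(torch.equal(before, after))  # False: the scripted part was trained
```

Scripting piece by piece like this lets each part be checked against its eager version before the whole model is converted and exported.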

3 Likes

Thanks. I notice that torch.jit.trace requires an example input (a random tensor). What is this argument for? Once I pass it, must every input I feed the model in C++ have the same shape as the example input?
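A small experiment that probes this question, assuming a recent PyTorch release (the function `flip_if_positive` is hypothetical). `torch.jit.trace` runs the module once on the example input and records the tensor operations it executes. For a plain convolution the recorded graph does not hard-code the input shape, so other batch and spatial sizes still work. What does get frozen is Python control flow that depends on tensor values, because the tracer only sees the branch taken for the example. A traced model loaded in C++ runs the same recorded graph, so the same behavior would be expected there.

```python
import warnings

import torch
import torch.nn as nn

conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)

# Trace with the example shape from the question.
traced_conv = torch.jit.trace(conv, torch.randn(16, 3, 256, 128))

# A different batch size and spatial size still works for a plain convolution.
out = traced_conv(torch.randn(2, 3, 64, 48))
print(out.shape)  # torch.Size([2, 8, 64, 48])


def flip_if_positive(x):
    # Data-dependent control flow: the tracer only records the branch
    # taken for the example input.
    if bool(x.sum() > 0):
        return -x
    return x


with warnings.catch_warnings():
    warnings.simplefilter("ignore")  # silence the TracerWarning about the branch
    traced_flip = torch.jit.trace(flip_if_positive, torch.ones(3))

neg = -torch.ones(3)
print(flip_if_positive(neg))  # tensor([-1., -1., -1.]): eager takes the other branch
print(traced_flip(neg))       # tensor([1., 1., 1.]): the traced graph always negates
```

If a model needs this kind of control flow, `torch.jit.script` compiles the Python source instead of recording a single run, so both branches are kept.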