Torch.jit.trace() only works on example input?

I’ve created a model with a forward function that takes `x` as input (an image of size (3, 416, 416)). I create a trace of the model using `module = torch.jit.trace(model, example_forward_input)`, then save it with `module.save("model.pt")`. I then load this traced model into an Android application. When I send an input to the model (from the phone) that is identical to the one used as `example_forward_input`, I get the correct result. However, with any other input tensor (same shape), I get poor results. Is this supposed to be the behaviour of the trace function? Is there a way to trace a model so that it generalizes to any input? Any guidance would be much appreciated.
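(For reference, a minimal sketch of the export flow described above; the stand-in model and the batched input shape are assumptions based on the post.)

```python
import torch
import torch.nn as nn

# stand-in for the YOLOv3-based model from the post (any nn.Module is traced the same way)
model = nn.Sequential(nn.Conv2d(3, 16, 3, padding=1), nn.ReLU())
model.eval()

# a single example input with a batch dimension: (1, 3, 416, 416)
example_forward_input = torch.rand(1, 3, 416, 416)

# trace records the operations executed for this particular example input
module = torch.jit.trace(model, example_forward_input)
module.save("model.pt")  # this file is what gets loaded from the Android app
```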

For some more detail: this is a YOLOv3-based model that does both detection and classification. For inputs other than the example, the traced model’s classification results are similar to the original model’s, but the detection locations (width/height especially) differ whenever the input is not the one that was used as the tracing example.

EDIT: I’m guessing this is because my forward method uses control flow that depends on the input, as outlined here. However, when I try to convert the model to a script module, as outlined on that same page, I get the following error:

```
    raise NotSupportedError(ctx_range, _vararg_kwarg_err)
torch.jit.frontend.NotSupportedError: Compiled functions can't take variable number of arguments or use keyword-only arguments with defaults:
  File "C:\Users\isaac\Anaconda3\envs\SNProject\lib\site-packages\torch\nn\modules\module.py", line 85
    def forward(self, *input):
                      ~~~~~~ <--- HERE
        print(torch.__version__)
        r"""Defines the computation performed at every call.
```

As you can see, the error points at `forward(self, *input)` inside the torch library itself rather than at my own code. Any suggestions on how to proceed?
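(A toy illustration of that suspicion, unrelated to the actual YOLO code: with data-dependent control flow, tracing records only the branch taken for the example input, whereas scripting keeps the `if`. Assumes a PyTorch version where `torch.jit.script` accepts an `nn.Module` instance.)

```python
import torch
import torch.nn as nn

class Toy(nn.Module):
    def forward(self, x):
        # data-dependent control flow: which branch runs depends on the values in x
        if x.sum() > 0:
            return x * 2
        return -x

example = torch.ones(3)
traced = torch.jit.trace(Toy(), example)   # emits a TracerWarning; only the "x * 2" branch is recorded
scripted = torch.jit.script(Toy())         # compiles both branches

neg = -torch.ones(3)
print(traced(neg))    # tensor([-2., -2., -2.]) -- still the branch recorded at trace time
print(scripted(neg))  # tensor([1., 1., 1.])    -- the correct branch for this input
```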


You are right about the input-dependent control flow requiring torch.jit.script instead of torch.jit.trace. Can you link the YOLOv3 implementation you’re using so we can reproduce this error?

I solved the problem. I ended up still using `torch.jit.trace()`, but had my `YOLOLayer` inherit from `ScriptModule`:

```python
class YOLOLayer(torch.jit.ScriptModule):
```

and decorated my forward method

```python
@torch.jit.script_method
def forward(self, x, targets=torch.tensor([]), img_dim=torch.tensor(416)):
```

and my helper method

```python
@torch.jit.script_method
def compute_grid_offsets(self, grid_size):
```

with `@torch.jit.script_method`. After that, I went line by line and fixed any errors that came up from incompatibility with scripting.
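(In case it helps others, here is a generic sketch of that pattern with the pre-1.2 API: a scripted submodule inside an otherwise traced model. The module contents are placeholders; the real YOLOLayer code is further down the thread.)

```python
import torch
import torch.nn as nn

class YOLOLayer(torch.jit.ScriptModule):
    # scripted, so its input-dependent logic survives inside a traced parent
    @torch.jit.script_method
    def forward(self, x):
        # ... input-dependent logic goes here ...
        return x

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.backbone = nn.Conv2d(3, 255, 1)  # placeholder for the rest of the network
        self.yolo = YOLOLayer()

    def forward(self, x):
        return self.yolo(self.backbone(x))

# tracing the full model keeps the scripted YOLOLayer intact
module = torch.jit.trace(Model(), torch.rand(1, 3, 416, 416))
module.save("model.pt")
```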


Good to hear you fixed it! We changed the TorchScript API in PyTorch 1.2 to make it easier to use (i.e. you no longer need to change your model to inherit from ScriptModule instead of nn.Module, and you don’t need @script_method); you can read more about it here. But this is just sugar over the same thing you’re already doing, so if you already have it working you don’t need to change anything.
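(For anyone reading this later, the ≥ 1.2 equivalent of the pattern above looks roughly like this; the layer body is a placeholder.)

```python
import torch
import torch.nn as nn

class YOLOLayer(nn.Module):          # stays a plain nn.Module
    def forward(self, x):
        # input-dependent control flow is preserved once the module is scripted
        return x

# no ScriptModule subclassing or @script_method needed
scripted_layer = torch.jit.script(YOLOLayer())
# it can then be used as a submodule of a model that is traced, exactly as before
```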

Hi, @IsaacBerman

I’m trying to do the same thing, but I’m running into errors and can’t reproduce your result.
Could you share the fixed code?
Is the code based on https://github.com/eriklindernoren/PyTorch-YOLOv3?

Hi @junjihashimoto,

Yes it is. The only layer you have to change is YOLOLayer. Since this was just for tracing, I commented out the loss calculations. As follows:

```python
import torch
import torch.nn as nn


# Small scripted helper functions used by YOLOLayer below
@torch.jit.script
def compare_size(size1, size2):
    return size1 != size2

@torch.jit.script
def get_input(x):
    return x

@torch.jit.script
def get_pred_boxes(x, grid):
    return x + grid

@torch.jit.script
def set_grid_size(x):
    # grid size = spatial size of the incoming feature map
    return torch.tensor(x.size(2))

@torch.jit.script
def normalize_by_stride(anchors, stride):
    return torch.div(anchors, stride)


class YOLOLayer(torch.jit.ScriptModule):
    """Detection layer"""

    def __init__(self, anchors, num_classes, img_dim=416):
        super(YOLOLayer, self).__init__()
        self.anchors = torch.tensor(anchors)
        self.num_anchors = len(anchors)
        self.num_classes = num_classes
        self.ignore_thres = 0.5
        self.mse_loss = nn.MSELoss()
        self.bce_loss = nn.BCELoss()
        self.obj_scale = 1
        self.noobj_scale = 100
        self.metrics = {}
        self.img_dim = torch.tensor(img_dim)
        self.grid_size = torch.tensor(0)  # grid size
        self.stride = torch.tensor(0)
        self.grid_x = torch.tensor([])
        self.grid_y = torch.tensor([])
        self.scaled_anchors = torch.tensor([])
        self.anchor_w = torch.tensor([])
        self.anchor_h = torch.tensor([])

    @torch.jit.script_method
    def compute_grid_offsets(self, grid_size):
        self.grid_size = grid_size.float()
        g = self.grid_size.int()

        self.grid_size = self.grid_size.float()

        self.stride = self.img_dim / self.grid_size
        self.stride = self.stride.float()

        self.grid_x = torch.arange(g).repeat(g, 1).view([1, 1, int(g.item()), int(g.item())])
        self.grid_y = torch.arange(g).repeat(g, 1).t().view([1, 1, int(g.item()), int(g.item())])

        self.scaled_anchors = torch.div(self.anchors, self.stride)

        self.anchor_w = self.scaled_anchors[:, 0:1].reshape(1, self.num_anchors, 1, 1)
        self.anchor_h = self.scaled_anchors[:, 1:2].reshape(1, self.num_anchors, 1, 1)

    @torch.jit.script_method
    def forward(self, x, targets=torch.tensor([]), img_dim=torch.tensor(416)):
        self.img_dim = img_dim
        num_samples = x.size(0)

        grid_size = set_grid_size(x)

        self.compute_grid_offsets(grid_size)

        prediction = (
            x.view(num_samples, self.num_anchors, self.num_classes + 5, grid_size, grid_size)
            .permute(0, 1, 3, 4, 2)
            .contiguous()
        )

        # Get outputs
        x = torch.sigmoid(prediction[..., 0])  # Center x
        y = torch.sigmoid(prediction[..., 1])  # Center y
        w = prediction[..., 2]  # Width
        h = prediction[..., 3]  # Height
        pred_conf = torch.sigmoid(prediction[..., 4])  # Conf
        pred_cls = torch.sigmoid(prediction[..., 5:])  # Cls pred.

        # print("YOLO_LAYER: {}".format(x[0][0][0]))
        # Add offset and scale with anchors
        pred_boxes = torch.zeros(prediction[..., :4].shape)
        pred_boxes = torch.stack(
            (
                x.data + self.grid_x,
                y.data + self.grid_y,
                torch.exp(w.data) * self.anchor_w,
                torch.exp(h.data) * self.anchor_h,
            ),
            4,
        )

        output = torch.cat(
            (
                pred_boxes.view(num_samples, -1, 4) * self.stride,
                pred_conf.view(num_samples, -1, 1),
                pred_cls.view(num_samples, -1, self.num_classes),
            ),
            -1,
        )
        # print(output[0][0][0])

        # Loss calculation commented out, since this module is only used for inference/tracing:
        # if targets is None:
        #     return output, 0
        # else:
        #     iou_scores, class_mask, obj_mask, noobj_mask, tx, ty, tw, th, tcls, tconf = build_targets(
        #         pred_boxes=pred_boxes,
        #         pred_cls=pred_cls,
        #         target=targets,
        #         anchors=self.scaled_anchors,
        #         ignore_thres=self.ignore_thres,
        #     )

        #     # Loss : Mask outputs to ignore non-existing objects (except with conf. loss)
        #     loss_x = self.mse_loss(x[obj_mask], tx[obj_mask])
        #     loss_y = self.mse_loss(y[obj_mask], ty[obj_mask])
        #     loss_w = self.mse_loss(w[obj_mask], tw[obj_mask])
        #     loss_h = self.mse_loss(h[obj_mask], th[obj_mask])
        #     loss_conf_obj = self.bce_loss(pred_conf[obj_mask], tconf[obj_mask])
        #     loss_conf_noobj = self.bce_loss(pred_conf[noobj_mask], tconf[noobj_mask])
        #     loss_conf = self.obj_scale * loss_conf_obj + self.noobj_scale * loss_conf_noobj
        #     loss_cls = self.bce_loss(pred_cls[obj_mask], tcls[obj_mask])
        #     total_loss = loss_x + loss_y + loss_w + loss_h + loss_conf + loss_cls

        #     # Metrics
        #     cls_acc = 100 * class_mask[obj_mask].mean()
        #     conf_obj = pred_conf[obj_mask].mean()
        #     conf_noobj = pred_conf[noobj_mask].mean()
        #     conf50 = (pred_conf > 0.5).float()
        #     iou50 = (iou_scores > 0.5).float()
        #     iou75 = (iou_scores > 0.75).float()
        #     detected_mask = conf50 * class_mask * tconf
        #     precision = torch.sum(iou50 * detected_mask) / (conf50.sum() + 1e-16)
        #     recall50 = torch.sum(iou50 * detected_mask) / (obj_mask.sum() + 1e-16)
        #     recall75 = torch.sum(iou75 * detected_mask) / (obj_mask.sum() + 1e-16)

        #     self.metrics = {
        #         "loss": to_cpu(total_loss).item(),
        #         "x": to_cpu(loss_x).item(),
        #         "y": to_cpu(loss_y).item(),
        #         "w": to_cpu(loss_w).item(),
        #         "h": to_cpu(loss_h).item(),
        #         "conf": to_cpu(loss_conf).item(),
        #         "cls": to_cpu(loss_cls).item(),
        #         "cls_acc": to_cpu(cls_acc).item(),
        #         "recall50": to_cpu(recall50).item(),
        #         "recall75": to_cpu(recall75).item(),
        #         "precision": to_cpu(precision).item(),
        #         "conf_obj": to_cpu(conf_obj).item(),
        #         "conf_noobj": to_cpu(conf_noobj).item(),
        #         "grid_size": grid_size,
        #     }
        # total_loss = 0
        return output, torch.tensor(0)
```
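(For completeness, exporting the full model then looks roughly like the following; the `Darknet` class and the config/weights paths are assumptions based on the eriklindernoren/PyTorch-YOLOv3 repo referenced above.)

```python
import torch
from models import Darknet  # from the PyTorch-YOLOv3 repo, with the modified YOLOLayer above

model = Darknet("config/yolov3.cfg", img_size=416)
model.load_darknet_weights("weights/yolov3.weights")
model.eval()

example_forward_input = torch.rand(1, 3, 416, 416)
traced = torch.jit.trace(model, example_forward_input)
traced.save("model.pt")  # ready to load from the mobile app
```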

Thank you for your help.
I got it working.

Is it possible to deploy a network without the classifier layer onto iOS? I just want to output the final features and compare two tensors in some way, e.g. Euclidean distance. Did you deploy on iOS?
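(For concreteness, the comparison I have in mind is just an L2 distance between the two feature tensors, something like the sketch below; the feature size is a placeholder.)

```python
import torch

# two feature vectors produced by the network with the classifier head removed
feat_a = torch.rand(1, 512)
feat_b = torch.rand(1, 512)

# Euclidean (L2) distance between the two feature tensors
distance = torch.dist(feat_a, feat_b, p=2)
print(distance)
```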

Sorry for the late reply, this was an Android deployment.

Care to explain how you did this?
I am using the same YOLO implementation, and I am stuck on the same error, @IsaacBerman:

```
torch.jit.frontend.NotSupportedError: Compiled functions can't take variable number of arguments or use keyword-only arguments with defaults:
  File "C:\Users\isaac\Anaconda3\envs\SNProject\lib\site-packages\torch\nn\modules\module.py", line 85
    def forward(self, *input):
                      ~~~~~~ <--- HERE
        print(torch.__version__)
        r"""Defines the computation performed at every call
```