Hi @junjihashimoto,
Yes, it is. The only layer you have to change is YOLOLayer. Since this was just for tracing, I commented out the loss calculations, as follows:
@torch.jit.script
def compare_size(size1: torch.Tensor, size2: torch.Tensor) -> torch.Tensor:
    """Return the element-wise ``size1 != size2`` comparison of two tensors.

    Both arguments are tensors (TorchScript's default for unannotated
    parameters, made explicit here), so the result is a boolean tensor
    rather than a Python ``bool``.
    """
    return size1 != size2
@torch.jit.script
def get_input(x: torch.Tensor) -> torch.Tensor:
    """Return ``x`` unchanged (scripted identity function)."""
    return x
@torch.jit.script
def get_pred_boxes(x: torch.Tensor, grid: torch.Tensor) -> torch.Tensor:
    """Return ``x + grid``, i.e. ``x`` shifted by the given grid offsets.

    Standard tensor broadcasting applies between the two operands.
    """
    return x + grid
@torch.jit.script
def set_grid_size(x: torch.Tensor) -> torch.Tensor:
    """Return the size of dimension 2 of ``x`` as a 0-dim tensor.

    The value is wrapped in a tensor (instead of being returned as a Python
    int) so that it stays a tensor inside the scripted graph.
    """
    return torch.tensor(x.size(2))
@torch.jit.script
def normalize_by_stride(anchors: torch.Tensor, stride: torch.Tensor) -> torch.Tensor:
    """Return ``anchors`` divided element-wise by ``stride``."""
    return torch.div(anchors, stride)
class YOLOLayer(torch.jit.ScriptModule):
    """Detection layer.

    TorchScript variant of the YOLO detection layer. The loss and metric
    computation is commented out, so ``forward`` only decodes the raw feature
    map into box / confidence / class predictions.
    """

    def __init__(self, anchors, num_classes, img_dim=416):
        """Store the anchors and allocate placeholder state.

        Args:
            anchors: sequence of anchors; converted with ``torch.tensor`` and
                later indexed as ``[:, 0]`` / ``[:, 1]``, so it is presumably
                a list of ``(width, height)`` pairs -- confirm against caller.
            num_classes: number of class scores per anchor.
            img_dim: input image size used to derive the stride.
        """
        super(YOLOLayer, self).__init__()
        self.anchors = torch.tensor(anchors)
        self.num_anchors = len(anchors)
        self.num_classes = num_classes
        self.ignore_thres = 0.5
        self.mse_loss = nn.MSELoss()
        self.bce_loss = nn.BCELoss()
        self.obj_scale = 1
        self.noobj_scale = 100
        self.metrics = {}
        self.img_dim = torch.tensor(img_dim)
        # The attributes below are placeholders; they are overwritten by
        # compute_grid_offsets() on every forward pass.
        self.grid_size = torch.tensor(0)  # grid size
        self.stride = torch.tensor(0)
        self.grid_x = torch.tensor([])
        self.grid_y = torch.tensor([])
        self.scaled_anchors = torch.tensor([])
        self.anchor_w = torch.tensor([])
        self.anchor_h = torch.tensor([])

    @torch.jit.script_method
    def compute_grid_offsets(self, grid_size):
        """Recompute stride, cell offsets and scaled anchors for ``grid_size``.

        Args:
            grid_size: 0-dim tensor holding the side length of the grid.
        """
        self.grid_size = grid_size.float()
        g = self.grid_size.int()
        side = int(g.item())
        self.stride = (self.img_dim / self.grid_size).float()
        # NOTE(review): the offsets are created with arange defaults (integer
        # dtype, default device); they are not moved to the input's device.
        self.grid_x = torch.arange(side).repeat(side, 1).view([1, 1, side, side])
        self.grid_y = torch.arange(side).repeat(side, 1).t().view([1, 1, side, side])
        self.scaled_anchors = torch.div(self.anchors, self.stride)
        self.anchor_w = self.scaled_anchors[:, 0:1].reshape(1, self.num_anchors, 1, 1)
        self.anchor_h = self.scaled_anchors[:, 1:2].reshape(1, self.num_anchors, 1, 1)

    @torch.jit.script_method
    def forward(self, x, targets=torch.tensor([]), img_dim=torch.tensor(416)):
        """Decode the feature map ``x`` into detections.

        Args:
            x: feature map; it is viewed as
                ``(num_samples, num_anchors, num_classes + 5, grid, grid)``
                where ``grid`` is ``x.size(2)``.
            targets: unused while the loss computation is commented out.
            img_dim: tensor holding the input image size; stored on the
                module and used to derive the stride.

        Returns:
            A tuple ``(output, loss)``. ``output`` concatenates, along the
            last dimension, the boxes (multiplied by the stride), the object
            confidence and the class scores. ``loss`` is always the constant
            ``torch.tensor(0)`` because the loss code is disabled.
        """
        self.img_dim = img_dim
        num_samples = x.size(0)
        grid_size = set_grid_size(x)
        self.compute_grid_offsets(grid_size)

        prediction = (
            x.view(num_samples, self.num_anchors, self.num_classes + 5, grid_size, grid_size)
            .permute(0, 1, 3, 4, 2)
            .contiguous()
        )

        # Get outputs
        x = torch.sigmoid(prediction[..., 0])  # Center x
        y = torch.sigmoid(prediction[..., 1])  # Center y
        w = prediction[..., 2]  # Width
        h = prediction[..., 3]  # Height
        pred_conf = torch.sigmoid(prediction[..., 4])  # Conf
        pred_cls = torch.sigmoid(prediction[..., 5:])  # Cls pred.

        # Add offset and scale with anchors. detach() replaces the discouraged
        # ``.data`` access; the boxes are built directly with stack(), so no
        # pre-allocated zeros tensor is needed.
        pred_boxes = torch.stack(
            (
                x.detach() + self.grid_x,
                y.detach() + self.grid_y,
                torch.exp(w.detach()) * self.anchor_w,
                torch.exp(h.detach()) * self.anchor_h,
            ),
            4,
        )

        output = torch.cat(
            (
                pred_boxes.view(num_samples, -1, 4) * self.stride,
                pred_conf.view(num_samples, -1, 1),
                pred_cls.view(num_samples, -1, self.num_classes),
            ),
            -1,
        )

        # Loss and metric computation, disabled for the scripted version:
        # if targets is None:
        #     return output, 0
        # else:
        #     iou_scores, class_mask, obj_mask, noobj_mask, tx, ty, tw, th, tcls, tconf = build_targets(
        #         pred_boxes=pred_boxes,
        #         pred_cls=pred_cls,
        #         target=targets,
        #         anchors=self.scaled_anchors,
        #         ignore_thres=self.ignore_thres,
        #     )
        #     # Loss : Mask outputs to ignore non-existing objects (except with conf. loss)
        #     loss_x = self.mse_loss(x[obj_mask], tx[obj_mask])
        #     loss_y = self.mse_loss(y[obj_mask], ty[obj_mask])
        #     loss_w = self.mse_loss(w[obj_mask], tw[obj_mask])
        #     loss_h = self.mse_loss(h[obj_mask], th[obj_mask])
        #     loss_conf_obj = self.bce_loss(pred_conf[obj_mask], tconf[obj_mask])
        #     loss_conf_noobj = self.bce_loss(pred_conf[noobj_mask], tconf[noobj_mask])
        #     loss_conf = self.obj_scale * loss_conf_obj + self.noobj_scale * loss_conf_noobj
        #     loss_cls = self.bce_loss(pred_cls[obj_mask], tcls[obj_mask])
        #     total_loss = loss_x + loss_y + loss_w + loss_h + loss_conf + loss_cls
        #     # Metrics
        #     cls_acc = 100 * class_mask[obj_mask].mean()
        #     conf_obj = pred_conf[obj_mask].mean()
        #     conf_noobj = pred_conf[noobj_mask].mean()
        #     conf50 = (pred_conf > 0.5).float()
        #     iou50 = (iou_scores > 0.5).float()
        #     iou75 = (iou_scores > 0.75).float()
        #     detected_mask = conf50 * class_mask * tconf
        #     precision = torch.sum(iou50 * detected_mask) / (conf50.sum() + 1e-16)
        #     recall50 = torch.sum(iou50 * detected_mask) / (obj_mask.sum() + 1e-16)
        #     recall75 = torch.sum(iou75 * detected_mask) / (obj_mask.sum() + 1e-16)
        #     self.metrics = {
        #         "loss": to_cpu(total_loss).item(),
        #         "x": to_cpu(loss_x).item(),
        #         "y": to_cpu(loss_y).item(),
        #         "w": to_cpu(loss_w).item(),
        #         "h": to_cpu(loss_h).item(),
        #         "conf": to_cpu(loss_conf).item(),
        #         "cls": to_cpu(loss_cls).item(),
        #         "cls_acc": to_cpu(cls_acc).item(),
        #         "recall50": to_cpu(recall50).item(),
        #         "recall75": to_cpu(recall75).item(),
        #         "precision": to_cpu(precision).item(),
        #         "conf_obj": to_cpu(conf_obj).item(),
        #         "conf_noobj": to_cpu(conf_noobj).item(),
        #         "grid_size": grid_size,
        #     }
        return output, torch.tensor(0)