I am trying to use `functorch.compile.memory_efficient_fusion` on a Mask R-CNN-type model. While doing so, I ran into several problems with data-dependent control flow. In particular, the error message is:
RuntimeError: It appears that you're trying to get value out of a tracing tensor - erroring out! It's likely that this is caused by data-dependent control flow or similar.
One interesting piece of code that causes the compilation to fail is the following (the snippet runs on its own with randomly generated data):
```python
import torch

num_gts = 14
num_anchors = 1000
assigned_gt_inds = torch.randint(-1, 10, size=(num_anchors,))
gt_max_overlaps = torch.rand((num_gts,))
min_pos_iou = 0.3
gt_max_assign_all = True
overlaps = torch.rand((num_gts, num_anchors))
gt_argmax_overlaps = torch.randint(0, num_anchors, size=(num_gts,))

for i in range(num_gts):
    # Python `if` on a tensor value: this is the data-dependent branch.
    if gt_max_overlaps[i] >= min_pos_iou:
        if gt_max_assign_all:
            # Every anchor whose IoU with GT i equals GT i's max IoU.
            max_iou_inds = overlaps[i, :] == gt_max_overlaps[i]
            assigned_gt_inds[max_iou_inds] = i + 1
        else:
            # Only the single best anchor for GT i.
            assigned_gt_inds[gt_argmax_overlaps[i]] = i + 1
```
How can I rewrite this snippet so that it is compatible with functorch? At first, I thought I would need a double `torch.where`, but then I realized that this is not possible: `assigned_gt_inds` must only be changed if `gt_max_overlaps[i] >= min_pos_iou`; otherwise `assigned_gt_inds` remains unchanged. However, the positions at which this tensor is changed depend on the second condition. Please also note that `overlaps` is a 2D tensor, so I'm not sure how to compute this condition.
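One way to express "change only if the condition holds, otherwise keep the old value" is a single `torch.where(mask, new, old)`, where the mask combines both conditions. The sketch below keeps the Python loop over GTs (its length comes from `overlaps.shape`, a plain Python int, so it is not data-dependent) and turns each tensor-valued `if` into a boolean mask. The function name `assign_stepwise` is hypothetical, and I have not verified that `memory_efficient_fusion` compiles it; it only avoids calling `bool()` on a tensor.

```python
import torch


def assign_stepwise(assigned_gt_inds, overlaps, gt_max_overlaps,
                    gt_argmax_overlaps, min_pos_iou, gt_max_assign_all):
    """Same loop over GTs, but each `if` on a tensor value becomes a mask."""
    num_gts, num_anchors = overlaps.shape  # Python ints, so the loop is static
    anchor_ids = torch.arange(num_anchors, device=overlaps.device)
    for i in range(num_gts):
        # 0-d bool tensor instead of a Python `if` (the first condition).
        gt_ok = gt_max_overlaps[i] >= min_pos_iou
        if gt_max_assign_all:  # plain Python bool, safe to branch on
            # (num_anchors,) mask: anchors whose IoU equals GT i's max IoU.
            hit = overlaps[i, :] == gt_max_overlaps[i]
        else:
            # (num_anchors,) one-hot mask selecting GT i's best anchor.
            hit = anchor_ids == gt_argmax_overlaps[i]
        # Write i + 1 where both conditions hold; keep the old value elsewhere.
        update = hit & gt_ok
        assigned_gt_inds = torch.where(
            update, torch.full_like(assigned_gt_inds, i + 1), assigned_gt_inds)
    return assigned_gt_inds
```

Because `torch.where` returns a new tensor instead of writing in place, the caller's `assigned_gt_inds` is left untouched. The loop is unrolled into `num_gts` steps when traced, which is fine for a small, fixed number of GTs.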
Does somebody have an idea how to rewrite the snippet above to work with functorch?
For reference, the code comes from mmdetection (`max_iou_assigner.py` in the 3.x branch of open-mmlab/mmdetection on GitHub). A relevant functorch GitHub issue is "Data-dependent control flow exploration" (pytorch/functorch issue #257).
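The loop over GTs can also be removed entirely by computing both conditions as one `(num_gts, num_anchors)` mask with broadcasting, which also answers how to handle the 2D `overlaps`. A minimal sketch, assuming the loop semantics above (a later GT overwrites an earlier one at the same anchor) and `num_gts > 0`: per anchor, the winner is the largest `i` whose mask entry is true, which `amax` over GT labels `i + 1` recovers. The names `assign_loop` and `assign_vectorized` are hypothetical, and I have not checked this against `memory_efficient_fusion` itself.

```python
import torch


def assign_loop(assigned_gt_inds, overlaps, gt_max_overlaps,
                gt_argmax_overlaps, min_pos_iou, gt_max_assign_all):
    """Reference: the original loop with a data-dependent `if` per GT."""
    assigned_gt_inds = assigned_gt_inds.clone()
    for i in range(overlaps.shape[0]):
        if gt_max_overlaps[i] >= min_pos_iou:
            if gt_max_assign_all:
                assigned_gt_inds[overlaps[i, :] == gt_max_overlaps[i]] = i + 1
            else:
                assigned_gt_inds[gt_argmax_overlaps[i]] = i + 1
    return assigned_gt_inds


def assign_vectorized(assigned_gt_inds, overlaps, gt_max_overlaps,
                      gt_argmax_overlaps, min_pos_iou, gt_max_assign_all):
    """Branch-free version: tensor ops only. Assumes num_gts > 0."""
    num_gts, num_anchors = overlaps.shape
    device = overlaps.device
    # (num_gts, 1): which GTs pass the IoU threshold (the first `if`).
    gt_ok = (gt_max_overlaps >= min_pos_iou).unsqueeze(1)
    if gt_max_assign_all:  # plain Python bool, so this branch is static
        # (num_gts, num_anchors): anchors whose IoU equals that GT's max IoU.
        hit = overlaps == gt_max_overlaps.unsqueeze(1)
    else:
        # (num_gts, num_anchors): one-hot row per GT at its best anchor.
        anchor_ids = torch.arange(num_anchors, device=device)
        hit = gt_argmax_overlaps.unsqueeze(1) == anchor_ids.unsqueeze(0)
    hit = hit & gt_ok
    # Label GT i as i + 1 (0 means "no GT hit this anchor"); the largest
    # label per anchor is the GT the loop would have written last.
    labels = torch.arange(1, num_gts + 1, device=device).unsqueeze(1)
    winner = (hit * labels).amax(dim=0)  # fails on an empty dim if num_gts == 0
    return torch.where(winner > 0, winner, assigned_gt_inds)
```

This trades the `num_gts` sequential steps for one `(num_gts, num_anchors)` temporary, which should be cheap at these sizes (14 × 1000 in the snippet above).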