I am trying to use functorch.compile.memory_efficient_fusion on a Mask R-CNN type model. While doing so, I ran into several problems with data-dependent control flow. In particular, the error message is:
RuntimeError: It appears that you're trying to get value out of a tracing tensor - erroring out! It's likely that this is caused by data-dependent control flow or similar.
One interesting piece of code that causes the compilation to fail is the following (the snippet below runs as-is with randomly generated data):
```python
import torch

num_gts = 14
num_anchors = 1000
assigned_gt_inds = torch.randint(-1, 10, size=(num_anchors,))
gt_max_overlaps = torch.rand((num_gts,))
min_pos_iou = 0.3
gt_max_assign_all = True
overlaps = torch.rand((num_gts, num_anchors))
gt_argmax_overlaps = torch.randint(0, num_anchors, size=(num_gts,))

for i in range(num_gts):
    # Data-dependent branch: Python `if` on a tensor value
    if gt_max_overlaps[i] >= min_pos_iou:
        if gt_max_assign_all:
            max_iou_inds = overlaps[i, :] == gt_max_overlaps[i]
            assigned_gt_inds[max_iou_inds] = i + 1
        else:
            assigned_gt_inds[gt_argmax_overlaps[i]] = i + 1
```
How can I rewrite this snippet so that it is compliant with functorch? At first, I thought I’d be looking into a double torch.where, but then I realized that this is not possible: assigned_gt_inds needs to be changed only if gt_max_overlaps[i] >= min_pos_iou; otherwise, assigned_gt_inds remains unchanged. However, the positions at which this tensor is changed depend on the second condition. Please also note that overlaps is a 2D tensor, so I’m not sure how to calculate this condition.
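One direction that might work is to evaluate both conditions for all ground truths at once as a 2D mask of shape (num_gts, num_anchors), so that no Python `if` ever looks at a tensor value. Since later iterations of the loop overwrite earlier ones, the loop's result for each anchor is the largest matching `i + 1`, which a `max` over the ground-truth dimension reproduces; a final `torch.where` then keeps the old value wherever nothing matched. This is only a sketch: `assign_vectorized` is a hypothetical name, it assumes `num_gts > 0`, and I have not verified it under memory_efficient_fusion.

```python
import torch


def assign_vectorized(assigned_gt_inds, overlaps, gt_max_overlaps,
                      gt_argmax_overlaps, min_pos_iou, gt_max_assign_all):
    """Hypothetical loop-free version of the snippet; assumes num_gts > 0."""
    num_gts, num_anchors = overlaps.shape
    # First condition as a mask instead of a Python `if`: shape (num_gts, 1)
    valid = (gt_max_overlaps >= min_pos_iou).unsqueeze(1)
    # gt_max_assign_all is a plain Python bool, so branching on it is not data-dependent
    if gt_max_assign_all:
        # Anchors whose IoU equals the gt's max IoU: shape (num_gts, num_anchors)
        hit = overlaps == gt_max_overlaps.unsqueeze(1)
    else:
        # One-hot mask of each gt's argmax anchor: shape (num_gts, num_anchors)
        anchor_ids = torch.arange(num_anchors, device=overlaps.device)
        hit = anchor_ids.unsqueeze(0) == gt_argmax_overlaps.unsqueeze(1)
    match = hit & valid
    # Candidate value i + 1 where gt i matches, 0 elsewhere
    gt_ids = torch.arange(1, num_gts + 1, device=overlaps.device,
                          dtype=assigned_gt_inds.dtype).unsqueeze(1)
    candidate = torch.where(match, gt_ids, torch.zeros_like(gt_ids))
    # The loop lets the last matching gt win, i.e. the largest index
    best = candidate.max(dim=0).values  # 0 means no gt matched this anchor
    # Keep the original assignment where nothing matched
    return torch.where(best > 0, best, assigned_gt_inds)
```

The price of removing the loop is a temporary (num_gts, num_anchors) mask, which is the same size as overlaps itself. Unlike the loop, the function returns a new tensor instead of modifying assigned_gt_inds in place.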
Does somebody have an idea how to rewrite the snippet above to work with functorch?
For reference, the code comes from mmdetection (mmdetection/max_iou_assigner.py at 3.x · open-mmlab/mmdetection · GitHub), and I would also like to point to a relevant GitHub issue: Data-dependent control flow exploration · Issue #257 · pytorch/functorch · GitHub