Data Dependent Control Flow - Functorch

I am trying to use functorch.compile.memory_efficient_fusion on a Mask R-CNN-type model. While doing so, I ran into several problems related to data-dependent control flow. In particular, the error message is:

RuntimeError: It appears that you're trying to get value out of a tracing tensor - erroring out! It's likely that this is caused by data-dependent control flow or similar.

One snippet that causes the compilation to fail is the following (it runs as-is on randomly generated data):

import torch

# Random stand-ins for the assigner's tensors
num_gts = 14
num_anchors = 1000
assigned_gt_inds = torch.randint(-1, 10, size=(num_anchors,))
gt_max_overlaps = torch.rand((num_gts,))
min_pos_iou = 0.3
gt_max_assign_all = True
overlaps = torch.rand((num_gts, num_anchors))
gt_argmax_overlaps = torch.randint(0, num_anchors, size=(num_gts,))

for i in range(num_gts):
    # Data-dependent branch: `if` needs the concrete value of a tensor
    if gt_max_overlaps[i] >= min_pos_iou:
        if gt_max_assign_all:
            # Assign gt i to every anchor whose IoU equals gt i's maximum IoU
            max_iou_inds = overlaps[i, :] == gt_max_overlaps[i]
            assigned_gt_inds[max_iou_inds] = i + 1
        else:
            # Assign gt i only to its single argmax anchor
            assigned_gt_inds[gt_argmax_overlaps[i]] = i + 1
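For comparison, here is a minimal sketch (not taken from mmdetection) that keeps the Python loop but removes the data-dependent if. The loop only depends on num_gts, which is a shape rather than tensor data, so a tracer can simply unroll it; gt_max_assign_all is a plain Python bool, so branching on it is static as well. The function name assign_loop_branchless is hypothetical, and the sketch assumes the inputs have the shapes used above.

```python
import torch


def assign_loop_branchless(assigned_gt_inds, overlaps, gt_max_overlaps,
                           gt_argmax_overlaps, min_pos_iou, gt_max_assign_all):
    # Hypothetical helper: same result as the loop above, without `if` on tensors.
    num_gts, num_anchors = overlaps.shape
    anchor_ids = torch.arange(num_anchors, device=overlaps.device)
    out = assigned_gt_inds.clone()  # do not mutate the caller's tensor
    for i in range(num_gts):  # fixed trip count: unrolled when traced
        # Tensor-valued condition instead of a Python `if` on a tensor
        is_pos = gt_max_overlaps[i] >= min_pos_iou
        if gt_max_assign_all:  # plain Python bool: static control flow
            hit = overlaps[i, :] == gt_max_overlaps[i]
        else:
            # One-hot mask of gt i's argmax anchor
            hit = anchor_ids == gt_argmax_overlaps[i]
        # Write i + 1 only where both conditions hold; keep `out` elsewhere
        out = torch.where(hit & is_pos, torch.full_like(out, i + 1), out)
    return out
```

Calling it as assigned_gt_inds = assign_loop_branchless(assigned_gt_inds, overlaps, gt_max_overlaps, gt_argmax_overlaps, min_pos_iou, gt_max_assign_all) should give the same values as the loop above, except that it returns a new tensor instead of modifying assigned_gt_inds in place.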

How can I rewrite this snippet so that it is compatible with functorch? At first I considered a nested torch.where, but then I realized that this is not possible: assigned_gt_inds must be changed only if gt_max_overlaps[i] >= min_pos_iou; otherwise it remains unchanged. However, the positions at which to change this tensor depend on the second condition. Please also note that overlaps is a 2D tensor, so I’m not sure how to compute this condition.
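One possible fully vectorized rewrite, sketched under the assumption that a (num_gts, num_anchors) boolean mask fits in memory (overlaps already has that shape). Broadcasting gt_max_overlaps[:, None] against overlaps evaluates the second condition for every gt at once, and the "remains unchanged" case becomes the fallback argument of torch.where. Because later loop iterations overwrite earlier ones, the sketch keeps the largest label i + 1 per anchor using amax. The name assign_vectorized is hypothetical, and this is not mmdetection's code.

```python
import torch


def assign_vectorized(assigned_gt_inds, overlaps, gt_max_overlaps,
                      gt_argmax_overlaps, min_pos_iou, gt_max_assign_all):
    # Hypothetical loop-free equivalent of the snippet in the question.
    num_gts, num_anchors = overlaps.shape
    if num_gts == 0:  # shape check, not data-dependent; loop would be a no-op
        return assigned_gt_inds.clone()
    # (num_gts, 1): which ground truths pass the IoU threshold
    is_pos = (gt_max_overlaps >= min_pos_iou)[:, None]
    if gt_max_assign_all:  # plain Python bool: static control flow
        # Compare each row of overlaps with that gt's max IoU -> (num_gts, num_anchors)
        hit = overlaps == gt_max_overlaps[:, None]
    else:
        # One-hot of each gt's argmax anchor -> (num_gts, num_anchors)
        anchor_ids = torch.arange(num_anchors, device=overlaps.device)
        hit = anchor_ids[None, :] == gt_argmax_overlaps[:, None]
    mask = hit & is_pos
    # Candidate label i + 1 wherever iteration i would write, 0 elsewhere
    labels = torch.arange(1, num_gts + 1, device=overlaps.device)[:, None]
    candidates = torch.where(mask, labels, torch.zeros_like(labels))
    # Later iterations win in the loop, so keep the largest label per anchor
    last = candidates.amax(dim=0)
    # 0 means no iteration wrote this anchor: keep the original value
    return torch.where(last > 0, last.to(assigned_gt_inds.dtype), assigned_gt_inds)
```

The design choice here is to replace the sequential "last writer wins" semantics with a max over labels, which only works because the labels i + 1 increase with the loop index.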

Does somebody have an idea how to rewrite the snippet above to work with functorch?

For reference, the code comes from mmdetection (max_iou_assigner.py on the 3.x branch of open-mmlab/mmdetection on GitHub). A related GitHub issue is "Data-dependent control flow exploration" (pytorch/functorch, issue #257).