I am trying to use `functorch.compile.memory_efficient_fusion` on a Mask R-CNN type model. While doing so, I ran into several problems with data-dependent control flow. In particular, the error message is:

```
RuntimeError: It appears that you're trying to get value out of a tracing tensor - erroring out! It's likely that this is caused by data-dependent control flow or similar.
```

One piece of code that causes the compilation to fail is the following (the snippet below runs on its own with randomly generated data):

```
import torch

num_gts = 14
num_anchors = 1000
assigned_gt_inds = torch.randint(-1, 10, size=(num_anchors,))
gt_max_overlaps = torch.rand((num_gts,))
min_pos_iou = 0.3
gt_max_assign_all = True
overlaps = torch.rand((num_gts, num_anchors))
gt_argmax_overlaps = torch.randint(0, num_anchors, size=(num_gts,))

for i in range(num_gts):
    if gt_max_overlaps[i] >= min_pos_iou:
        if gt_max_assign_all:
            max_iou_inds = overlaps[i, :] == gt_max_overlaps[i]
            assigned_gt_inds[max_iou_inds] = i + 1
        else:
            assigned_gt_inds[gt_argmax_overlaps[i]] = i + 1
```
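For the `else` branch of the loop above, the per-iteration `if` on a tensor value can be folded into a boolean mask. The sketch below is a minimal illustration under these assumptions: `assigned_gt_inds` is an int64 tensor as in the snippet, `num_gts > 0`, and ground truths are processed in index order, so when two ground truths share the same argmax anchor, the later `i` wins (as it does in the loop). The helper name `assign_argmax_vectorized` is hypothetical and is not part of functorch or mmdetection. It has not been checked under `memory_efficient_fusion` itself. The only claim is that it contains no Python branching on tensor values.

```python
import torch


def assign_argmax_vectorized(assigned_gt_inds, gt_max_overlaps, gt_argmax_overlaps, min_pos_iou):
    """Hypothetical helper: loop-free form of the `else` branch.

    Mirrors, for i in increasing order:
        if gt_max_overlaps[i] >= min_pos_iou:
            assigned_gt_inds[gt_argmax_overlaps[i]] = i + 1
    Assumes num_gts > 0 (max over an empty dim would fail).
    """
    num_gts = gt_max_overlaps.shape[0]
    num_anchors = assigned_gt_inds.shape[0]
    device = assigned_gt_inds.device

    # (num_gts,): which ground truths pass the IoU threshold
    keep = gt_max_overlaps >= min_pos_iou
    anchor_ids = torch.arange(num_anchors, device=device)
    # (num_gts, num_anchors): gt i passes AND its argmax anchor is j
    hits = (gt_argmax_overlaps[:, None] == anchor_ids[None, :]) & keep[:, None]

    # Label i + 1 where there is a hit, 0 elsewhere; the max over gts picks the
    # largest matching i + 1, i.e. the last write in the original loop.
    labels = torch.arange(1, num_gts + 1, device=device)[:, None]
    candidate = (labels * hits).max(dim=0).values

    # Anchors without any hit keep their previous value.
    return torch.where(hits.any(dim=0), candidate, assigned_gt_inds)
```

Unlike the loop, this returns a new tensor instead of writing into `assigned_gt_inds` in place.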

How can I rewrite this snippet so that it is compliant with `functorch`? At first, I thought I would need a double `torch.where`, but then I realized that this is not possible: `assigned_gt_inds` only needs to be changed if `gt_max_overlaps[i] >= min_pos_iou`; otherwise, `assigned_gt_inds` remains unchanged. However, the positions at which this tensor is changed depend on the second condition. Please also note that `overlaps` is a 2D tensor, so I'm not sure how to calculate this condition.
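One possible way to express the 2D condition is broadcasting. Comparing every row of `overlaps` with its own `gt_max_overlaps[i]` gives a `(num_gts, num_anchors)` mask, and the IoU threshold becomes a per-row mask. Combining the two and taking, per anchor, the largest matching `i + 1` reproduces the loop's "last write wins" behavior. A single `torch.where` then leaves all other entries unchanged. This is a sketch under stated assumptions, not a verified fix. It assumes `assigned_gt_inds` is int64 and `num_gts > 0`. The helper name `assign_all_max_vectorized` is hypothetical. It has not been tested with `memory_efficient_fusion`.

```python
import torch


def assign_all_max_vectorized(assigned_gt_inds, overlaps, gt_max_overlaps, min_pos_iou):
    """Hypothetical helper: loop-free form of the `gt_max_assign_all` branch.

    Mirrors, for i in increasing order:
        if gt_max_overlaps[i] >= min_pos_iou:
            assigned_gt_inds[overlaps[i, :] == gt_max_overlaps[i]] = i + 1
    Assumes num_gts > 0 and an integer (int64) assigned_gt_inds.
    """
    num_gts = overlaps.shape[0]

    # (num_gts, 1): per-gt threshold test, broadcast across anchors
    keep = (gt_max_overlaps >= min_pos_iou)[:, None]
    # (num_gts, num_anchors): 2D version of `overlaps[i, :] == gt_max_overlaps[i]`
    is_max = overlaps == gt_max_overlaps[:, None]
    hits = is_max & keep

    # Label i + 1 where there is a hit, 0 elsewhere; the max over gts keeps the
    # largest matching i + 1, matching the loop's final value for that anchor.
    labels = torch.arange(1, num_gts + 1, device=overlaps.device)[:, None]
    candidate = (labels * hits).max(dim=0).values

    # Anchors with no hit in any row keep their previous value.
    return torch.where(hits.any(dim=0), candidate, assigned_gt_inds)
```

Because `gt_max_assign_all` is a plain Python bool rather than a tensor, choosing between this helper and the `else`-branch version with an ordinary `if` does not depend on tensor data.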

Does somebody have an idea how to rewrite the snippet above to work with functorch?

For reference, the code comes from mmdetection (`max_iou_assigner.py` on the 3.x branch of open-mmlab/mmdetection on GitHub). A relevant GitHub ticket for PyTorch is "Data-dependent control flow exploration" (issue #257 in pytorch/functorch).