How can I register a pytree for a class so that I can apply functorch's `memory_efficient_fusion`?

For Mask R-CNN-style predictions, a sampling result is calculated during the forward pass. The sampling result is wrapped inside the following class:

```
from torch import Tensor


class SamplingResult:
    def __init__(self, pos_inds: Tensor, neg_inds: Tensor):
        self.pos_inds = pos_inds  # indices of positive samples
        self.neg_inds = neg_inds  # indices of negative samples
```

During functorch's compilation phase, I get the following error:

```
RuntimeError: Found <class 'multimodal.models.middle_networks.random_sampler.SamplingResult'> in output, which is not a known type. If this type holds tensors, you need to register a pytree for it. See https://github.com/pytorch/functorch/issues/475 for a brief explanation why. If you don't need to register a pytree, please leave a comment explaining your use case and we'll make this more ergonomic to deal with
```

I looked at the linked issue, but I couldn't figure out how to register a pytree for this class. Is there any documentation about it?
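My best guess so far, as a minimal sketch: I'm assuming that AOT Autograd (which `memory_efficient_fusion` builds on) flattens outputs with `torch.utils._pytree`, so registering the class there should make it a known type. Registration needs a flatten function that returns the tensors plus a static context, and an unflatten function that rebuilds the object from those tensors. `torch.utils._pytree` is a private module, so the registration function's name depends on the PyTorch version; the helper names `_flatten_sampling_result` and `_unflatten_sampling_result` are my own, hypothetical ones.

```python
import torch
from torch import Tensor
import torch.utils._pytree as pytree  # private module: names may change between releases


class SamplingResult:
    def __init__(self, pos_inds: Tensor, neg_inds: Tensor):
        self.pos_inds = pos_inds
        self.neg_inds = neg_inds


def _flatten_sampling_result(result: SamplingResult):
    # Children: the tensors that pytree utilities should traverse.
    # Context: static, non-tensor data needed to rebuild the object (none here).
    return [result.pos_inds, result.neg_inds], None


def _unflatten_sampling_result(values, context) -> SamplingResult:
    # Rebuild from the (possibly transformed) children, in the same order
    # that _flatten_sampling_result produced them.
    pos_inds, neg_inds = values
    return SamplingResult(pos_inds, neg_inds)


# Assumption: newer PyTorch releases expose register_pytree_node, while older
# ones only provide the underscore-prefixed _register_pytree_node.
_register = getattr(pytree, "register_pytree_node", None) or pytree._register_pytree_node
_register(SamplingResult, _flatten_sampling_result, _unflatten_sampling_result)
```

A round trip through `pytree.tree_flatten` and `pytree.tree_unflatten` should give back a `SamplingResult` holding the same tensors, which seems like an easy way to check the registration before running `memory_efficient_fusion`.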