How to register a pytree for an output class?

How can I register a pytree for a class so that I can apply functorch’s memory_efficient_fusion?

For Mask R-CNN type predictions, a sampling result is calculated during the forward pass. The sampling result is wrapped inside the following class:

from torch import Tensor

class SamplingResult:

    def __init__(self, pos_inds: Tensor, neg_inds: Tensor):
        self.pos_inds = pos_inds
        self.neg_inds = neg_inds

During the compilation phase for functorch, I receive the following error:

RuntimeError: Found <class 'multimodal.models.middle_networks.random_sampler.SamplingResult'> in output, which is not a known type. If this type holds tensors, you need to register a pytree for it. See https://github.com/pytorch/functorch/issues/475 for a brief explanation why. If you don't need to register a pytree, please leave a comment explaining your use case and we'll make this more ergonomic to deal with

I looked at the linked issue but I couldn’t figure out how to create a pytree from this. Is there some documentation about it?

I figured out a way to do it. For reference, the JAX docs on custom pytree nodes turned out to be very helpful: "Working with Pytrees" (JAX documentation).
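For anyone landing here later, this is a minimal sketch of what such a registration might look like, not necessarily the exact code I ended up with. It assumes functorch looks up custom types through torch.utils._pytree, which is a private module, so the function name may differ between PyTorch versions (the getattr fallback below is my guess at covering both). The helper names _flatten and _unflatten are just illustrative.

```python
import torch
from torch import Tensor
import torch.utils._pytree as pytree


class SamplingResult:
    def __init__(self, pos_inds: Tensor, neg_inds: Tensor):
        self.pos_inds = pos_inds
        self.neg_inds = neg_inds


def _flatten(result):
    # Return the tensor children plus a static context (nothing needed here).
    return [result.pos_inds, result.neg_inds], None


def _unflatten(children, context):
    # Rebuild the object from its children, in the order _flatten produced them.
    pos_inds, neg_inds = children
    return SamplingResult(pos_inds, neg_inds)


# Assumption: newer PyTorch releases expose register_pytree_node, while older
# ones only have the underscore-prefixed _register_pytree_node.
_register = getattr(pytree, "register_pytree_node", None) or pytree._register_pytree_node
_register(SamplingResult, _flatten, _unflatten)

# Round trip: flatten into tensor leaves, then rebuild from the leaves and spec.
result = SamplingResult(torch.tensor([0, 2]), torch.tensor([1, 3]))
leaves, spec = pytree.tree_flatten(result)
rebuilt = pytree.tree_unflatten(leaves, spec)
```

The idea is the same as in the JAX docs: the flatten function tells the tracer which attributes are tensors, and the unflatten function rebuilds the object from them. The registration has to run before the model is compiled.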