How can I register a pytree for a class so that I can apply functorch’s memory_efficient_fusion?
For Mask R-CNN-style predictions, a sampling result is computed during the forward pass. The sampling result is wrapped in the following class:
```python
from torch import Tensor


class SamplingResult:
    def __init__(self, pos_inds: Tensor, neg_inds: Tensor):
        self.pos_inds = pos_inds
        self.neg_inds = neg_inds
```
During functorch’s compilation phase, I get the following error:
RuntimeError: Found <class 'multimodal.models.middle_networks.random_sampler.SamplingResult'> in output, which is not a known type. If this type holds tensors, you need to register a pytree for it. See https://github.com/pytorch/functorch/issues/475 for a brief explanation why. If you don't need to register a pytree, please leave a comment explaining your use case and we'll make this more ergonomic to deal with
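As far as I can tell, the error comes from how outputs are flattened. The sketch below assumes functorch walks the outputs with `torch.utils._pytree.tree_flatten`, which is what the error message implies. An unregistered class is not traversed: it ends up as one opaque leaf, so the tensors inside it are never exposed. The class is repeated here only so the snippet runs on its own.

```python
import torch
from torch import Tensor
from torch.utils._pytree import tree_flatten


class SamplingResult:
    def __init__(self, pos_inds: Tensor, neg_inds: Tensor):
        self.pos_inds = pos_inds
        self.neg_inds = neg_inds


sr = SamplingResult(torch.tensor([0, 2]), torch.tensor([1, 3]))
out = (torch.ones(2), sr)

# The tuple is a known container and is traversed, but SamplingResult is
# unknown, so it is returned as a single leaf instead of its two tensors.
leaves, spec = tree_flatten(out)
non_tensor_leaves = [type(x).__name__ for x in leaves if not isinstance(x, Tensor)]
print(non_tensor_leaves)  # ['SamplingResult']
```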
I looked at the linked issue, but I couldn’t figure out how to register a pytree for this class. Is there any documentation on it?
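One approach that may work is a minimal sketch of registering the class with PyTorch's pytree registry. It assumes that functorch uses `torch.utils._pytree` to flatten outputs. It also assumes that your PyTorch version provides either `register_pytree_node` (newer releases) or the private `_register_pytree_node` (older, functorch-era releases). Both take `(cls, flatten_fn, unflatten_fn)`. The helpers `_flatten_sampling_result` and `_unflatten_sampling_result` are hypothetical names chosen for this example. I have not verified this against `memory_efficient_fusion` itself, only the flatten/unflatten round trip.

```python
from typing import Any, Iterable, List, Tuple

import torch
import torch.utils._pytree as pytree
from torch import Tensor


class SamplingResult:
    def __init__(self, pos_inds: Tensor, neg_inds: Tensor):
        self.pos_inds = pos_inds
        self.neg_inds = neg_inds


def _flatten_sampling_result(sr: SamplingResult) -> Tuple[List[Tensor], Any]:
    # Hypothetical helper: return the tensor children plus a context.
    # This class holds no non-tensor metadata, so the context is None.
    return [sr.pos_inds, sr.neg_inds], None


def _unflatten_sampling_result(values: Iterable[Tensor], context: Any) -> SamplingResult:
    # Hypothetical helper: rebuild the object from the children,
    # in the same order the flatten function emitted them.
    pos_inds, neg_inds = values
    return SamplingResult(pos_inds, neg_inds)


# Assumption: newer PyTorch exposes register_pytree_node, older releases only
# have the private _register_pytree_node; both accept these three arguments.
_register = getattr(pytree, "register_pytree_node", None) or pytree._register_pytree_node
_register(SamplingResult, _flatten_sampling_result, _unflatten_sampling_result)

# Round trip: the two tensors are now exposed as individual leaves.
sr = SamplingResult(torch.tensor([0, 2]), torch.tensor([1, 3]))
leaves, spec = pytree.tree_flatten(sr)
rebuilt = pytree.tree_unflatten(leaves, spec)
print(len(leaves), type(rebuilt).__name__)  # 2 SamplingResult
```

If the class later gains non-tensor fields, such as counts or flags, those would typically go into the context returned by the flatten function rather than into the list of children.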