A new way to run and support autograd with meta tensors

When I tried to run autograd with a meta tensor input on vit_b_16, I discovered that some aten ops are not registered for the meta backend. So, following the suggestions in Function to automatically calculate Conv shape · Issue #79512 · pytorch/pytorch · GitHub, I tried to patch native_layer_norm.default for the meta backend.
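
For context, this is roughly the experiment that surfaced the problem (a minimal sketch; which op fails first and the exact message depend on the PyTorch version):

import torch
from torchvision.models import vit_b_16

model = vit_b_16().to('meta')
data = torch.rand(1, 3, 224, 224, device='meta')
try:
    model(data).sum().backward()
except (NotImplementedError, RuntimeError) as e:
    # e.g. "Could not run 'aten::...' with arguments from the 'Meta' backend"
    print(e)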

import torch
from torch._meta_registrations import register_meta  # import path may vary across versions

aten = torch.ops.aten

@register_meta(aten.native_layer_norm.default)
def meta_ln(
    input: torch.Tensor,
    normalized_shape, weight, bias, eps
):
    # mean and rstd keep the leading dims and collapse the normalized dims to 1
    axis = input.dim() - len(normalized_shape)
    stats_shape = list(input.shape[:axis]) + [1] * len(normalized_shape)

    output = torch.empty_like(input)
    mean = torch.empty(stats_shape, dtype=input.dtype, device='meta')
    rstd = torch.empty(stats_shape, dtype=input.dtype, device='meta')
    return output, mean, rstd

@register_meta(aten.native_layer_norm_backward.default)
def meta_ln_backward(
    dY: torch.Tensor,
    input: torch.Tensor,
    normalized_shape, mean, rstd, weight, bias, grad_input_mask
):
    # each gradient matches the shape of the tensor it corresponds to;
    # weight and bias are optional, so guard against None
    dX = torch.empty_like(input)
    dgamma = torch.empty_like(weight) if weight is not None else None
    dbeta = torch.empty_like(bias) if bias is not None else None
    return dX, dgamma, dbeta

However, even though the patch registers successfully, the autograd dispatcher refuses to route calls to my patched op for the meta backend:

RuntimeError: 0 INTERNAL ASSERT FAILED at "../aten/src/ATen/core/boxing/KernelFunction.cpp":23, please report a bug to PyTorch. aten::native_layer_norm has kernels registered to both CompositeImplicitAutograd and a backend mapped to AutogradOther. This makes the backend kernel unreachable; the dispatcher will always prefer the CompositeImplicitAutograd lowering (see Note [Ambiguity in AutogradOther kernel]). If you want to override CompositeImplicitAutograd, please open an issue to request a dedicated Autograd dispatch key for the backend.
If you only want to run inference instead of training, add `c10::InferenceMode mode;` before model.forward(). Note this guard is only available in C++ but not Python at present.

So, as discussed in CompositeImplicitAutograd operators should not perform operations that do not dispatch · Issue #61669 · pytorch/pytorch · GitHub, this failure is unavoidable on PyTorch 1.12.0 and below: aten::native_layer_norm has a CompositeImplicitAutograd kernel, which makes any meta kernel reached through AutogradOther unreachable. So I managed to develop another way to run autograd with meta tensors.
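
To see the conflicting registrations for yourself, you can dump the dispatch table for the op (this relies on an internal API, so the exact output format may differ across versions):

# prints every kernel registered for the op, including the
# CompositeImplicitAutograd entry that shadows the Meta one
print(torch._C._dispatch_dump("aten::native_layer_norm"))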

import torch
from torch.utils._pytree import tree_map


class MetaTensor(torch.Tensor):

    elem: torch.Tensor

    __slots__ = ['elem']

    @staticmethod
    def __new__(cls, elem):
        # report device='cpu' to deceive the frontend during aten op selection;
        # the real data lives in elem on the meta device
        r = torch.Tensor._make_wrapper_subclass(
            cls, elem.size(),
            strides=elem.stride(), storage_offset=elem.storage_offset(),
            dtype=elem.dtype, layout=elem.layout,
            device='cpu', requires_grad=elem.requires_grad
        )
        r.elem = elem
        return r

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        def unwrap(x):
            return x.elem.to('meta') if isinstance(x, MetaTensor) else x

        args = tree_map(unwrap, args)
        kwargs = tree_map(unwrap, kwargs or {})  # kwargs may be None
        out = func(*args, **kwargs)

        def wrap(x):
            return MetaTensor(x) if isinstance(x, torch.Tensor) else x

        return tree_map(wrap, out)
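
Before wiring it into a model, the wrapper can be exercised on a single op; here is a small sketch of my own to illustrate the unwrap/re-wrap round trip:

x = MetaTensor(torch.rand(4, 16, device='meta'))
y = (x @ x.t()).sum()     # both ops dispatch on meta tensors under the hood
print(type(y), y.shape)   # <class 'MetaTensor'> torch.Size([])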

Since previous work by the PyTorch team already supports most aten ops on the meta backend, we can simply deceive the autograd dispatcher into believing we are running on CPU. Because the wrapper reports device='cpu', dispatch takes the regular AutogradCPU path, where native_layer_norm has a real autograd kernel, instead of the ambiguous AutogradOther path; by the time __torch_dispatch__ fires, autograd has already been handled, so re-dispatching each op on the underlying meta tensors reaches our patched meta kernels directly. So now we can do forward and backward with meta tensors only, and trace a large model with batch_size=1e10 in milliseconds.

model = vit_b_16()
data = MetaTensor(torch.rand(int(1e10), 3, 224, 224, device='meta'))
model.to('meta')(data).sum().backward()
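
If this works as intended, every parameter should end up with a gradient of the matching shape; a quick sanity check of my own (not part of the original run) would be:

# hypothetical check: backward populated .grad for every parameter
for name, p in model.named_parameters():
    assert p.grad is not None and p.grad.shape == p.shape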

Since I have only tested consistency against PyTorch 1.12.0, I am curious how well this code extends to other versions and models. I would like to hear your suggestions about this implementation.