[MPS] When device='mps', aten.linear.default op is not decomposed

Summary

I was following the “Writing your own quantized tensor” tutorial on an Apple Silicon chip (M4 Pro), but unlike in the tutorial example, aten.linear.default is not decomposed into aten.mm.default and aten.t.default.

Code

Everything is the same as in the linked tutorial. To log which functions float_model encounters, I added a LoggingTensor following the examples in the torchao GitHub repo. The code I wrote is:

from typing import Tuple, List, Any

import torch
import torch.utils._pytree as pytree


class LoggingTensor(torch.Tensor):
    @staticmethod
    @torch._dynamo.disable
    def __new__(cls, a):
        return torch.Tensor._make_wrapper_subclass(
            cls,
            a.shape,
            strides = a.stride(),
            storage_offset = a.storage_offset(),
            dtype = a.dtype,
            device = a.device,
        )
    
    @torch._dynamo.disable
    def __init__(self, a):
        self.a = a


    def __tensor_flatten__(self) -> Tuple[List[str], Any]:
        return ["a"], None
    
    @classmethod
    def __tensor_unflatten__(cls, tensor_data_dict, extra_metadata, outer_size=None, outer_stride=None):
        assert extra_metadata is None
        a = tensor_data_dict["a"]
        return LoggingTensor(a)
    
    @classmethod
    def __torch_dispatch__(cls, func, types, args, kwargs):
        kwargs = kwargs or {} # if kwargs is None, set it to an empty dictionary
        print(f"func: {str(func)}")

        # Unwrapping any LoggingTensor arguments.
        # Calling the underlying function on the inner tensors
        # Wrapping any tensor outputs into LoggingTensor
        args_a = pytree.tree_map_only(LoggingTensor, lambda x: x.a, args)
        kwargs_a = pytree.tree_map_only(LoggingTensor, lambda x: x.a, kwargs)
        out_a = func(*args_a, **kwargs_a)
        out_a_flat, spec = pytree.tree_flatten(out_a)
        out_flat = [
            cls(o_a) if isinstance(o_a, torch.Tensor) else o_a for o_a in out_a_flat
        ]
        return pytree.tree_unflatten(out_flat, spec)


class ToyModel(torch.nn.Module):
    def __init__(self, m: int, n: int, k: int):
        super(ToyModel, self).__init__()
        self.linear1 = torch.nn.Linear(m, n, bias=False)
        self.linear2 = torch.nn.Linear(n, k, bias=False)
        
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.linear1(x)
        x = self.linear2(x)
        return x

if __name__ == "__main__":
    device = 'mps'
    float_model = ToyModel(64, 128, 32).to(device) # Same as tutorial ToyModel.
      
    for name, child in float_model.named_children():
        if type(child) == torch.nn.Linear:
            child.weight = torch.nn.Parameter(
                LoggingTensor(child.weight), requires_grad=True
                )
      
      
    with torch.no_grad():
        x = torch.randn(64, 64, 64).to(device)
        _ = float_model(x)

The output is:

func: aten.detach.default
func: aten.detach.default
func: aten.detach.default
func: aten.detach.default
func: aten.linear.default
func: aten.linear.default

When I tried the same tutorial and logging code in Google Colab with CUDA, it worked as intended, as described in the tutorial. So I think this is a problem related to MPS (or CPU).
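
For reference, the same trace can be repeated on CPU (reusing the LoggingTensor and ToyModel defined above) to check whether the decomposition happens there:

# Repeat the LoggingTensor trace on CPU to see whether aten.linear.default
# is decomposed there (uses the LoggingTensor / ToyModel defined above).
cpu_model = ToyModel(64, 128, 32)
for name, child in cpu_model.named_children():
    if type(child) == torch.nn.Linear:
        child.weight = torch.nn.Parameter(LoggingTensor(child.weight), requires_grad=True)

with torch.no_grad():
    _ = cpu_model(torch.randn(64, 64, 64))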

I’m thinking about modifying the op-registration code so it also handles torch.ops.aten.linear.default, but I am very new to this topic and don’t know where to start. Could anyone recommend resources for adapting the tutorial’s op registration for MPS?
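
For context, the kind of change I have in mind is roughly the following (an untested sketch that reuses the tutorial’s register_op and Int8SymmetricTensor, and assumes the tutorial’s aten.mm.default / aten.t.default handlers are still registered): intercept aten.linear.default and re-express it with mm and t, so the existing int8 matmul path is reused.

@register_op([torch.ops.aten.linear.default])
def int8_linear_via_mm(func, x, weight, bias=None):
    # Don't call aten.linear on the subclass; re-express it as
    # mm(x_2d, weight.t()) so the tutorial's aten.t.default and
    # aten.mm.default handlers for Int8SymmetricTensor take over.
    assert isinstance(weight, Int8SymmetricTensor)
    assert bias is None, "sketch only covers bias=False Linear layers"
    x_2d = x.reshape(-1, x.shape[-1])       # flatten batch dims for mm
    out_2d = torch.mm(x_2d, weight.t())     # re-dispatches to the int8 handlers
    return out_2d.reshape(*x.shape[:-1], -1)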

I don’t know whether the approach below is the right way to achieve the quantization benefits, but it works.

The op-registration part is replaced with:

from torch.utils._python_dispatch import return_and_correct_aliasing

@register_op([
    torch.ops.aten.detach.default,
])
def int8_view_ops(func, *args, **kwargs):
    assert isinstance(args[0], Int8SymmetricTensor)
    out_data = func(args[0].int_data, *args[1:], **kwargs)
    out_scale = func(args[0].scale, *args[1:], **kwargs)
    out = Int8SymmetricTensor(out_data, out_scale)
    return return_and_correct_aliasing(func, args, kwargs, out)


@register_op([torch.ops.aten.linear.default])
def int8_linear(func, x, weight):
    # x is the input activation tensor; weight is the quantized Int8SymmetricTensor.
    assert isinstance(weight, Int8SymmetricTensor), "Int8SymmetricTensor: linear currently only supports the weight in low precision, not the input!"
    out_data = func(x, weight.int_data.to(x.dtype)) * weight.scale.T.to(x.dtype)
    return return_and_correct_aliasing(func, (x, weight), {}, out_data)

And when comparing outputs, torch.compile also works with mode='max-autotune':

float_model = ToyModel(64, 128, 32).to(device)
quantized_model_module_swap = copy.deepcopy(float_model)
quantized_model_subclass = copy.deepcopy(float_model)

# Swap torch.nn.Linear with QuantizedLinear
for name, child in quantized_model_module_swap.named_children():
    if type(child) == torch.nn.Linear:
        new_linear = QuantizedLinear.from_float(child)
        setattr(quantized_model_module_swap, name, new_linear)

# Swap torch.nn.Linear with Int8SymmetricTensor
for name, child in quantized_model_subclass.named_children():
    if type(child) == torch.nn.Linear:
        subclass_param = Int8SymmetricTensor.from_float(child.weight)
        child.weight = torch.nn.Parameter(subclass_param, requires_grad=True)

with torch.no_grad():
    x = torch.randn(64, 64, 64).to(device)
    out_module_swap = quantized_model_module_swap(x)
    out_subclass = quantized_model_subclass(x)
    print(torch.allclose(out_subclass, out_module_swap))

    out_compiled = torch.compile(quantized_model_subclass, mode="max-autotune")(x)
    print(torch.allclose(out_subclass, out_compiled))

Output:

True
True