Summary
I was following the “Writing your own quantized tensor” tutorial on an Apple Silicon chip (M4 Pro), but unlike in the tutorial example, aten.linear.default is not decomposed into aten.mm.default and aten.t.default.
Code
Everything is the same as in the linked tutorial. To log which functions float_model encounters, I added a LoggingTensor, following the examples in the torchao GitHub repo. The code I wrote is the following:
from typing import Tuple, List, Any

import torch
import torch.utils._pytree as pytree


class LoggingTensor(torch.Tensor):
    @staticmethod
    @torch._dynamo.disable
    def __new__(cls, a):
        return torch.Tensor._make_wrapper_subclass(
            cls,
            a.shape,
            strides=a.stride(),
            storage_offset=a.storage_offset(),
            dtype=a.dtype,
            device=a.device,
        )

    @torch._dynamo.disable
    def __init__(self, a):
        self.a = a

    def __tensor_flatten__(self) -> Tuple[List[str], Any]:
        return ["a"], None

    @classmethod
    def __tensor_unflatten__(cls, tensor_data_dict, extra_metadata, outer_size=None, outer_stride=None):
        assert extra_metadata is None
        a = tensor_data_dict["a"]
        return LoggingTensor(a)

    @classmethod
    def __torch_dispatch__(cls, func, types, args, kwargs):
        kwargs = kwargs or {}  # if kwargs is None, use an empty dict
        print(f"func: {str(func)}")
        # Unwrap any LoggingTensor arguments, call the underlying
        # function on the inner tensors, then wrap any tensor outputs
        # back into LoggingTensor.
        args_a = pytree.tree_map_only(LoggingTensor, lambda x: x.a, args)
        kwargs_a = pytree.tree_map_only(LoggingTensor, lambda x: x.a, kwargs)
        out_a = func(*args_a, **kwargs_a)
        out_a_flat, spec = pytree.tree_flatten(out_a)
        out_flat = [
            cls(o_a) if isinstance(o_a, torch.Tensor) else o_a for o_a in out_a_flat
        ]
        return pytree.tree_unflatten(out_flat, spec)
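# Note: _make_wrapper_subclass creates an "outer" tensor with no storage of
# its own; the real data lives in self.a. Every aten op that touches a
# LoggingTensor is therefore routed through __torch_dispatch__ above, so the
# log below shows exactly which ops reach the backend.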
class ToyModel(torch.nn.Module):
    def __init__(self, m: int, n: int, k: int):
        super().__init__()
        self.linear1 = torch.nn.Linear(m, n, bias=False)
        self.linear2 = torch.nn.Linear(n, k, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.linear1(x)
        x = self.linear2(x)
        return x
if __name__ == "__main__":
    device = "mps"
    float_model = ToyModel(64, 128, 32).to(device)  # Same as the tutorial ToyModel.
    for name, child in float_model.named_children():
        if type(child) == torch.nn.Linear:
            child.weight = torch.nn.Parameter(
                LoggingTensor(child.weight), requires_grad=True
            )
    with torch.no_grad():
        x = torch.randn(64, 64, 64).to(device)
        _ = float_model(x)
The output is:
func: aten.detach.default
func: aten.detach.default
func: aten.detach.default
func: aten.detach.default
func: aten.linear.default
func: aten.linear.default
When I tried the same tutorial and logging code in Google Colab with CUDA, each linear call was decomposed into aten.t.default and aten.mm.default as the tutorial describes, so I think this is a problem related to MPS or CPU.
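To narrow this down without the full model, I put together a minimal check (my own sketch, not from the tutorial); switching device between "cpu", "mps", and "cuda" compares which aten op a single F.linear call dispatches to:

import torch
import torch.nn.functional as F

device = "mps"  # try "cpu" and "cuda" for comparison
w = LoggingTensor(torch.randn(128, 64).to(device))
x = torch.randn(4, 64).to(device)
with torch.no_grad():
    _ = F.linear(x, w)  # __torch_dispatch__ prints the op(s) actually hit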
I’m thinking about modifying the code that registers ops to work with torch.ops.aten.linear.default
, but I am very new to this topic, I don’t know where to start from. So, could anyone recommend any related resources for modifying the register ops code in tutorial for MPS?
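A rough sketch of what I have in mind is below. It is untested and assumes the tutorial’s `implements` decorator and its Int8SymmetricTensor fields (int_data and scale, as I read them from the tutorial); instead of relying on the mm/t decomposition, it falls back to a dequantized float linear:

import torch
import torch.nn.functional as F

# Untested sketch: register a handler for the un-decomposed linear op.
@implements([torch.ops.aten.linear.default])
def _(func, types, args, kwargs):
    x, w = args[0], args[1]
    bias = args[2] if len(args) > 2 else None
    # aten.linear(x, w, bias) computes x @ w.t() + bias, so dequantize
    # the int8 weight and call a plain float linear on the inner data.
    w_float = w.int_data.to(x.dtype) * w.scale
    return F.linear(x, w_float, bias)

I don’t know whether handling linear directly like this is the right approach, or whether there is a way to force the decomposition on MPS, so any pointers would be appreciated.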