How to gracefully mask CompositeImplicitAutograd for different backends
Background:
I implemented a torch.compile backend for my hardware via the PrivateUse1 dispatch key. I found that torch.compile by default decomposes upsample_nearest2d into a bunch of small operators through _upsample_nearest. On my hardware the _unsafe_index operator does not perform well, so I would like to call the custom upsample_nearest2d operator directly for better performance. The decomposition in PyTorch looks like this:
@register_decomposition([aten.upsample_nearest2d.default, aten.upsample_nearest2d.out])
@aten.upsample_nearest2d.default.py_impl(DispatchKey.CompositeImplicitAutograd)
@aten.upsample_nearest2d.default.py_impl(DispatchKey.Autograd)
@out_wrapper(preserve_memory_format=True, exact_dtype=True)
def upsample_nearest2d(
    input: Tensor,
    output_size: List[int],
    scales_h: Optional[float] = None,
    scales_w: Optional[float] = None,
) -> Tensor:
    return _upsample_nearest(input, output_size, [scales_h, scales_w])
@pw_cast_for_opmath
def _upsample_nearest(
    input: Tensor,
    output_size: List[int],
    scales: List[Optional[float]],
    exact: bool = False,
) -> Tensor:
    spatial_indices = _compute_upsample_nearest_indices(
        input, output_size, scales, exact=exact
    )
    indices = [None, None] + spatial_indices
    result = aten._unsafe_index(input, indices)
    if result.ndim == 4:
        # convert output to correct memory format, if necessary
        memory_format = utils.suggest_memory_format(input)
        # following "heuristic: only use channels_last path when it's faster than the contiguous path"
        n_channels = input.shape[1]
        if input.device.type == "cuda" and n_channels < 4:
            memory_format = torch.contiguous_format
        result = result.contiguous(memory_format=memory_format)
    return result
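A minimal sketch of how to observe this decomposition (assumptions on my side: the global torch._decomp.decomposition_table is the dict that register_decomposition populates, and make_fx applies it while tracing):

# Sketch: trace a nearest-neighbor upsample with the default decompositions applied.
import torch
import torch.nn.functional as F
from torch._decomp import decomposition_table
from torch.fx.experimental.proxy_tensor import make_fx

def f(x):
    return F.interpolate(x, scale_factor=2.0, mode="nearest")

x = torch.randn(1, 3, 8, 8)
gm = make_fx(f, decomposition_table=decomposition_table)(x)
# The printed graph should contain aten._unsafe_index / index computations
# instead of a single aten.upsample_nearest2d call.
print(gm.graph)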
Attempted Resolution:
I tried to cancel the DispatchKey.CompositeImplicitAutograd registration of upsample_nearest2d, like this:
from torch._C import DispatchKey
from torch._decomp import decomposition_table

def disable_implicit_decomposition():
    '''
    Since PyTorch implicitly decomposes some aten ops by default,
    disable those decompositions here to avoid poor performance after decomposition.
    '''
    disable_aten_ops = [
        'aten.upsample_nearest1d.vec', 'aten.upsample_nearest1d.default',
        'aten.upsample_nearest2d.vec', 'aten.upsample_nearest2d.default',
        'aten.upsample_nearest3d.vec', 'aten.upsample_nearest3d.default',
    ]
    for op_override in decomposition_table.keys():
        if str(op_override) in disable_aten_ops:
            if DispatchKey.Autograd in op_override.py_kernels:
                op_override.py_kernels.pop(DispatchKey.Autograd)
            if DispatchKey.CompositeImplicitAutograd in op_override.py_kernels:
                op_override.py_kernels.pop(DispatchKey.CompositeImplicitAutograd)
However, with this modification the upsample_nearest2d operator is dispatched to the C++ implementation of upsample_nearest2d, whose input.sizes() call raises an error when torch.compile(dynamic=True) is executed:
Tensor upsample_nearest2d(
    const Tensor& input,
    at::OptionalIntArrayRef output_size,
    std::optional<ArrayRef<double>> scale_factors) {
  auto osize = compute_output_size(input.sizes(), output_size, scale_factors);
  auto scale_h = get_scale_value(scale_factors, 0);
  auto scale_w = get_scale_value(scale_factors, 1);
  return at::upsample_nearest2d(input, osize, scale_h, scale_w);
}
The error raised:
Cannot call sizes() on tensor with symbolic sizes/strides
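My understanding (an assumption on my part, not something I found documented) is that under torch.compile(dynamic=True) the traced tensors carry symbolic (SymInt) sizes: the Python decomposition handles them through ordinary Python shape arithmetic, while the C++ kernel above calls sizes() instead of sym_sizes() and trips this check. A rough illustration on a stock backend:

# Sketch: Python-level shape arithmetic on SymInts works under
# torch.compile(dynamic=True); the failure above comes from the C++ path
# calling sizes() on a tensor whose sizes are symbolic.
import torch
import torch.nn.functional as F

@torch.compile(dynamic=True)
def f(x):
    h, w = x.shape[-2:]  # SymInts while tracing with dynamic shapes
    return F.interpolate(x, size=(h * 2, w * 2), mode="nearest")

print(f(torch.randn(1, 3, 8, 8)).shape)  # torch.Size([1, 3, 16, 16])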
Questions and Discussion Points
1. Why is DispatchKey.CompositeImplicitAutograd added for upsample_nearest2d? Is the root cause that input.sizes() cannot be called in the C++ implementation?
2. Why does the upsample_nearest2d decomposition need to call _upsample_nearest and decompose into a bunch of small operators? The problem I encountered in this scenario could also be solved by directly calling torch.ops.aten.upsample_nearest2d inside upsample_nearest2d:
@register_decomposition([aten.upsample_nearest2d.default, aten.upsample_nearest2d.out])
@aten.upsample_nearest2d.default.py_impl(DispatchKey.CompositeImplicitAutograd)
@aten.upsample_nearest2d.default.py_impl(DispatchKey.Autograd)
@out_wrapper(preserve_memory_format=True, exact_dtype=True)
def upsample_nearest2d(
    input: Tensor,
    output_size: List[int],
    scales_h: Optional[float] = None,
    scales_w: Optional[float] = None,
) -> Tensor:
    # return _upsample_nearest(input, output_size, [scales_h, scales_w])
    return torch.ops.aten.upsample_nearest2d(input, output_size, scales_h, scales_w)
3. Is there a more appropriate way to achieve my goal without splitting the operator during torch.compile?
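For reference, the closest alternative I could come up with is to hand the backend a filtered decomposition table instead of mutating py_kernels. This is an untested sketch: it assumes the backend is wrapped with torch._dynamo.backends.common.aot_autograd, that torch._decomp.remove_decompositions is available in the installed version, and my_fw_compiler is only a placeholder for my hardware's graph compiler. It also only touches the decomposition-table path, not the py_impl registrations shown above, which is part of what I am asking about.

# Untested sketch: remove the upsample_nearest decompositions from the table
# handed to an aot_autograd-based backend, instead of popping py_kernels.
import torch
import torch.nn.functional as F
from torch._decomp import core_aten_decompositions, remove_decompositions
from torch._dynamo.backends.common import aot_autograd

aten = torch.ops.aten

decomps = core_aten_decompositions()
remove_decompositions(
    decomps,
    [aten.upsample_nearest1d, aten.upsample_nearest2d, aten.upsample_nearest3d],
)

def my_fw_compiler(gm, example_inputs):
    # Placeholder: a real backend would compile gm for the target hardware here.
    return gm.forward

my_backend = aot_autograd(fw_compiler=my_fw_compiler, decompositions=decomps)

def f(x):
    return F.interpolate(x, scale_factor=2.0, mode="nearest")

compiled = torch.compile(f, backend=my_backend)
out = compiled(torch.randn(1, 3, 8, 8))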