How to gracefully mask CompositeImplicitAutograd for different backends


I implemented a torch.compile backend for my hardware via the PrivateUse1 dispatch key. I found that torch.compile by default decomposes upsample_nearest2d into a series of small operators through _upsample_nearest (decomposition shown below). On my hardware the _unsafe_index operator performs poorly, so I would like torch.compile to call my custom upsample_nearest2d operator directly for better performance.

@register_decomposition([aten.upsample_nearest2d.default, aten.upsample_nearest2d.out])
@out_wrapper(preserve_memory_format=True, exact_dtype=True)
def upsample_nearest2d(
    input: Tensor,
    output_size: List[int],
    scales_h: Optional[float] = None,
    scales_w: Optional[float] = None,
) -> Tensor:
    return _upsample_nearest(input, output_size, [scales_h, scales_w])


def _upsample_nearest(
    input: Tensor,
    output_size: List[int],
    scales: List[Optional[float]],
    exact: bool = False,
) -> Tensor:
    spatial_indices = _compute_upsample_nearest_indices(
        input, output_size, scales, exact=exact
    )

    indices = [None, None] + spatial_indices
    result = aten._unsafe_index(input, indices)

    if result.ndim == 4:
        # convert output to correct memory format, if necessary
        memory_format = utils.suggest_memory_format(input)

        # following "heuristic: only use channels_last path when it's faster than the contiguous path"
        n_channels = input.shape[1]
        if input.device.type == "cuda" and n_channels < 4:
            memory_format = torch.contiguous_format

        result = result.contiguous(memory_format=memory_format)
    return result
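
To make the two steps of this decomposition concrete, here is a minimal pure-Python sketch on nested lists instead of tensors. It is only an illustration under assumptions: the helper names nearest_indices and upsample_nearest2d_ref are hypothetical, it covers the non-exact nearest mode with no explicit scale factors, and it ignores batch and channel dimensions and memory format.

```python
import math
from typing import List


def nearest_indices(in_size: int, out_size: int) -> List[int]:
    """Source index for each output position along one axis (non-exact nearest)."""
    scale = in_size / out_size
    # floor(dst * scale), clamped to the last valid source index
    return [min(math.floor(dst * scale), in_size - 1) for dst in range(out_size)]


def upsample_nearest2d_ref(
    image: List[List[float]], out_h: int, out_w: int
) -> List[List[float]]:
    """Step 1: compute row/column indices. Step 2: gather with those indices."""
    rows = nearest_indices(len(image), out_h)
    cols = nearest_indices(len(image[0]), out_w)
    # This gather is the part that maps to the indexing operator on a real backend.
    return [[image[r][c] for c in cols] for r in rows]


small = [[1.0, 2.0],
         [3.0, 4.0]]
big = upsample_nearest2d_ref(small, 4, 4)
```

The gather in the second step is what ends up as _unsafe_index after decomposition, which is why the performance of that one operator dominates on my hardware.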

Attempted Resolution:

I tried to remove the DispatchKey.CompositeImplicitAutograd kernel of upsample_nearest2d, like this:

from torch._C import DispatchKey
from torch._decomp import decomposition_table


def disable_implicit_decomposition():
    """
    Since torch official will implicitly decompose some aten ops,
    disable some ops here to avoid poor performance after decompose.
    """
    disable_aten_ops = [
        'aten.upsample_nearest1d.vec', 'aten.upsample_nearest1d.default',
        'aten.upsample_nearest2d.vec', 'aten.upsample_nearest2d.default',
        'aten.upsample_nearest3d.vec', 'aten.upsample_nearest3d.default',
    ]

    for op_override in decomposition_table.keys():
        if str(op_override) in disable_aten_ops:
            # Remove the Python kernels registered under these dispatch keys.
            if DispatchKey.Autograd in op_override.py_kernels:
                op_override.py_kernels.pop(DispatchKey.Autograd)
            if DispatchKey.CompositeImplicitAutograd in op_override.py_kernels:
                op_override.py_kernels.pop(DispatchKey.CompositeImplicitAutograd)

However, with this change upsample_nearest2d is dispatched to its C++ implementation, and the input.sizes() call there fails when running under torch.compile(dynamic=True).

Tensor upsample_nearest2d(
    const Tensor& input,
    at::OptionalIntArrayRef output_size,
    std::optional<ArrayRef<double>> scale_factors) {
  auto osize = compute_output_size(input.sizes(), output_size, scale_factors);
  auto scale_h = get_scale_value(scale_factors, 0);
  auto scale_w = get_scale_value(scale_factors, 1);
  return at::upsample_nearest2d(input, osize, scale_h, scale_w);
}

This raises the following error:

Cannot call sizes() on tensor with symbolic sizes/strides

Questions and Discussion Points

1. Why is a DispatchKey.CompositeImplicitAutograd kernel registered for upsample_nearest2d? Is the root cause that the C++ implementation cannot call input.sizes() on a tensor with symbolic shapes?

2. Why does the upsample_nearest2d decomposition need to call _upsample_nearest and break the op into a bunch of small operators? The problem I ran into in this scenario can also be solved by calling torch.ops.aten.upsample_nearest2d directly inside upsample_nearest2d:

@register_decomposition([aten.upsample_nearest2d.default, aten.upsample_nearest2d.out])
@out_wrapper(preserve_memory_format=True, exact_dtype=True)
def upsample_nearest2d(
    input: Tensor,
    output_size: List[int],
    scales_h: Optional[float] = None,
    scales_w: Optional[float] = None,
) -> Tensor:
    # return _upsample_nearest(input, output_size, [scales_h, scales_w])
    return torch.ops.aten.upsample_nearest2d(input, output_size, scales_h, scales_w)

3. Is there a more appropriate way to achieve my goal, that is, to keep the operator from being split up during torch.compile?
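
For question 3, one direction I am considering is sketched below with plain dictionaries rather than the real torch table: instead of removing kernels from the op overloads globally, build a filtered copy of the decomposition table that leaves out the ops my backend implements natively, and hand only that copy to the compile pipeline. The function name without_ops, the string keys, and the placeholder lambdas are hypothetical. Whether the backend entry point accepts a custom table (for example through a decompositions argument) is an assumption I have not verified against my PyTorch version.

```python
from typing import Callable, Dict, Iterable


def without_ops(
    table: Dict[str, Callable], keep_native: Iterable[str]
) -> Dict[str, Callable]:
    """Return a copy of `table` that omits the ops the backend runs natively."""
    native = set(keep_native)
    return {op: fn for op, fn in table.items() if op not in native}


# Stand-in for a real decomposition table; the lambdas are placeholders.
global_table = {
    "aten.upsample_nearest2d.default": lambda *args: "decomposed",
    "aten.upsample_nearest2d.out": lambda *args: "decomposed",
    "aten.addmm.default": lambda *args: "decomposed",
}

# The global table is left untouched; only the backend sees the filtered copy.
backend_table = without_ops(
    global_table,
    ["aten.upsample_nearest2d.default", "aten.upsample_nearest2d.out"],
)
```

The point of filtering a copy is that other backends in the same process keep the default decompositions, which the py_kernels approach above does not guarantee.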