How to register a custom cross_entropy_loss to PrivateUse1 and support backward

I am adding a new backend to PyTorch, following Extending dispatcher for a new backend in C++ — PyTorch Tutorials 2.0.0+cu117 documentation. The new backend uses the reserved dispatch keys PrivateUse1/AutogradPrivateUse1. I have already registered some operators, such as add.Tensor and add.out, but I ran into problems when registering the cross_entropy_loss operator.

The operator registration code is as follows:

TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) {
    m.impl("cross_entropy_loss", custom_cross_entropy_loss);
}

The test code, named test_cross_entropy_loss.py, is as follows:

import torch
import torch.nn.functional as F
# new backend python module, like torch_xla
import torch_dipu

input = torch.randn(3, 5)
target = torch.randn(3, 5).softmax(dim=1)
device = torch.device("privateuseone")
input = input.to(device)
input.requires_grad_(True)
target = target.to(device)

loss = F.cross_entropy(input, target)
print(f"loss = {loss}")

loss.backward()
print(f"input.grad = {input.grad}")

When I run test_cross_entropy_loss.py, I get the following error:

Traceback (most recent call last):
  File "test_cross_entropy_loss.py", line 12, in <module>
    loss = F.cross_entropy(input, target)
  File "pytorch/torch/nn/functional.py", line 3029, in cross_entropy
    return torch._C._nn.cross_entropy_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index, label_smoothing)
NotImplementedError: Could not run 'aten::cross_entropy_loss' with arguments from the 'AutogradPrivateUse1' backend. This could be because the operator doesn't exist for this backend, or was omitted during the selective/custom build process (if using custom build). If you are a Facebook employee using PyTorch on mobile, please visit https://fburl.com/ptmfixes for possible resolutions. 'aten::cross_entropy_loss' is only available for these backends: [CPU, CUDA, HIP, XLA, MPS, IPU, XPU, HPU, VE, Lazy, Meta, MTIA, PrivateUse1, PrivateUse2, PrivateUse3, FPGA, ORT, Vulkan, Metal, QuantizedCPU, QuantizedCUDA, QuantizedHIP, QuantizedXLA, QuantizedMPS, QuantizedIPU, QuantizedXPU, QuantizedHPU, QuantizedVE, QuantizedLazy, QuantizedMeta, QuantizedMTIA, QuantizedPrivateUse1, QuantizedPrivateUse2, QuantizedPrivateUse3, CustomRNGKeyId, MkldnnCPU, SparseCPU, SparseCUDA, SparseHIP, SparseXLA, SparseMPS, SparseIPU, SparseXPU, SparseHPU, SparseVE, SparseLazy, SparseMeta, SparseMTIA, SparsePrivateUse1, SparsePrivateUse2, SparsePrivateUse3, SparseCsrCPU, SparseCsrCUDA, NestedTensorCPU, NestedTensorCUDA, NestedTensorHIP, NestedTensorXLA, NestedTensorMPS, NestedTensorIPU, NestedTensorXPU, NestedTensorHPU, NestedTensorVE, NestedTensorLazy, NestedTensorMeta, NestedTensorMTIA, NestedTensorPrivateUse1, NestedTensorPrivateUse2, NestedTensorPrivateUse3, BackendSelect, Python, FuncTorchDynamicLayerBackMode, Functionalize, Named, Conjugate, Negative, ZeroTensor, ADInplaceOrView, AutogradOther, AutogradCPU, AutogradCUDA, AutogradHIP, AutogradXLA, AutogradMPS, AutogradIPU, AutogradXPU, AutogradHPU, AutogradVE, AutogradLazy, AutogradMeta, AutogradMTIA, AutogradPrivateUse2, AutogradPrivateUse3, AutogradNestedTensor, Tracer, AutocastCPU, AutocastCUDA, FuncTorchBatched, FuncTorchVmapMode, Batched, VmapMode, FuncTorchGradWrapper, PythonTLSSnapshot, FuncTorchDynamicLayerFrontMode, PythonDispatcher].

So I changed the registration code to

TORCH_LIBRARY_IMPL(aten, AutogradPrivateUse1, m) {
    m.impl("cross_entropy_loss", custom_cross_entropy_loss);
}

I re-ran the code; the forward pass now runs successfully, but the backward pass reports an error:

Traceback (most recent call last):
  File "test_cross_entropy_loss.py", line 15, in <module>
    loss.backward()
  File "pytorch/torch/_tensor.py", line 488, in backward
    torch.autograd.backward(
  File "pytorch/torch/autograd/__init__.py", line 199, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

I have two questions:

  1. Why does registering cross_entropy_loss to PrivateUse1 fail?
  2. When cross_entropy_loss is registered to AutogradPrivateUse1, why does backward fail, and what should I do to implement a custom cross_entropy_loss backward operation?

I guess it may be related to cross_entropy_loss being registered as a CompositeImplicitAutograd op, but I haven't figured out the connection.

cross_entropy_loss is a CompositeImplicitAutograd op, which means there is no autograd formula registered for it. Instead, the operator decomposes into a series of more primitive operations.
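For reference, a rough sketch (not the exact ATen source) of what that decomposition boils down to for the plain class-index-target case looks like this; the probability-target and label-smoothing paths are similar but use different primitives.

#include <ATen/ATen.h>

// Rough sketch of the CompositeImplicitAutograd decomposition for plain
// class-index targets: cross_entropy_loss is essentially log_softmax followed
// by nll_loss, and each of those pieces already has its own autograd formula.
at::Tensor cross_entropy_loss_decomposition_sketch(
        const at::Tensor& self, const at::Tensor& target,
        const c10::optional<at::Tensor>& weight, int64_t reduction,
        int64_t ignore_index) {
    auto log_probs = at::log_softmax(self, /*dim=*/1);
    return at::nll_loss(log_probs, target, weight, reduction, ignore_index);
}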

If you register a custom backend kernel for cross_entropy_loss, it doesn't make sense to use that decomposition anymore, because that would mean your custom kernel wouldn't be run.

So for this case, you'd need to register a kernel to BOTH PrivateUse1 and AutogradPrivateUse1.


Conversely, the reason add.Tensor works is that it is a primitive op, i.e. it is not CompositeImplicitAutograd. Rather than a decomposition, it has a default autograd implementation available for you to use, which is why you only needed to register a kernel to PrivateUse1 in that case.
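In other words, for a primitive op a backend-only registration is all that is needed. A minimal sketch (custom_add and its body are just placeholders for whatever your backend actually does):

// Only a PrivateUse1 kernel is needed for a primitive op such as add.Tensor:
// autograd already has a derivative formula for it, so no AutogradPrivateUse1
// registration is required.
at::Tensor custom_add(const at::Tensor& self, const at::Tensor& other,
                      const at::Scalar& alpha) {
    at::Tensor result;
    // launch the backend's add kernel and produce result
    return result;
}

TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) {
    m.impl("add.Tensor", custom_add);
}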


Thank you for your reply.

Following your suggestion, I registered my custom cross_entropy_loss to both dispatch keys, i.e. PrivateUse1 and AutogradPrivateUse1.

TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) {
    m.impl("cross_entropy_loss", custom_cross_entropy_loss);
}

TORCH_LIBRARY_IMPL(aten, AutogradPrivateUse1, m) {
    m.impl("cross_entropy_loss", custom_cross_entropy_loss);
}

Then I re-ran the test code, and input.grad is None after backpropagation. It looks like the backpropagation logic is incorrect. What should I do to implement a custom cross_entropy_loss backward op and hook it into PyTorch autograd? Or does PyTorch not allow customizing the implementation of a CompositeImplicitAutograd op, so that I can only customize primitive ops like nll_loss in order to get a custom cross_entropy_loss?

@ryanc What makes this more challenging is that cross_entropy_loss has no derivative formula. If you want to implement a custom kernel for cross_entropy_loss and you want autograd to work, then you're on the hook for implementing a derivative formula too.

As soulitzer said, the reason cross_entropy_loss works successfully today is that it doesn't have a dedicated kernel implementation: instead, it has a decomposition into other ATen operators, each of which has its own dedicated autograd formula. One option would be to check that decomposition and write custom kernels for each of the primitive ops in it instead.
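For anyone going that route, the class-index-target path bottoms out in ops along the following lines. The custom_* kernels are placeholders for the backend's own implementations (their signatures must match the ATen schemas), and the exact set of ops depends on which decomposition path your inputs take.

TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) {
    // Primitive ops used by the cross_entropy_loss decomposition; each has its
    // own autograd formula, so the stock decomposition and autograd machinery
    // keep working and only the device compute is customized.
    m.impl("_log_softmax", custom_log_softmax);
    m.impl("_log_softmax_backward_data", custom_log_softmax_backward);
    m.impl("nll_loss_forward", custom_nll_loss_forward);
    m.impl("nll_loss_backward", custom_nll_loss_backward);
}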


Thanks for your reply!

I finally solved this problem; my solution follows Autograd in C++ Frontend — PyTorch Tutorials 1.8.1+cu102 documentation. The core code is as follows.

Registration code:

TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) {
    m.impl("cross_entropy_loss", custom_cross_entropy_loss);
}

TORCH_LIBRARY_IMPL(aten, AutogradPrivateUse1, m) {
    m.impl("cross_entropy_loss", custom_cross_entropy_loss);
}

Implementation of the forward and backward passes:

#include <torch/torch.h>

using torch::autograd::AutogradContext;
using torch::autograd::tensor_list;

class CrossEntropyLossFunction : public torch::autograd::Function<CrossEntropyLossFunction> {
public:
    static at::Tensor forward(
        AutogradContext *ctx, const at::Tensor &self, const at::Tensor &target,
            const c10::optional<at::Tensor> &weight_opt, int64_t reduction,
            c10::SymInt ignore_index, double label_smoothing) {
        ctx->saved_data["reduction"] = reduction;
        ctx->saved_data["ignore_index"] = ignore_index.expect_int();
        ctx->saved_data["label_smoothing"] = label_smoothing;
        bool weight_has_value = weight_opt.has_value();
        ctx->saved_data["weight_has_value"] = weight_has_value;

        // Exclude autograd-related dispatch keys below this point so that ops
        // called while computing the result go straight to the backend kernels.
        at::AutoDispatchBelowADInplaceOrView g;
        if (!weight_has_value) {
            ctx->save_for_backward({self, target});
        } else {
            ctx->save_for_backward({self, target, weight_opt.value()});
        }

        at::Tensor result;
        // do the compute and get result
        return result;
    }

    static tensor_list backward(AutogradContext *ctx, tensor_list grad_outputs) {
        auto reduction = ctx->saved_data["reduction"].toInt();
        auto ignore_index = ctx->saved_data["ignore_index"].toInt();
        auto label_smoothing = ctx->saved_data["label_smoothing"].toDouble();
        auto weight_has_value = ctx->saved_data["weight_has_value"].toBool();
        auto saved = ctx->get_saved_variables();
        auto input = saved[0];
        auto target = saved[1];
        c10::optional<at::Tensor> weight = c10::nullopt;
        if (weight_has_value) {
            weight.emplace(saved[2]);
        }

        at::Tensor grad_input;
        // do the compute and get grad_input
        return {grad_input, at::Tensor(), at::Tensor(), at::Tensor(), at::Tensor(), at::Tensor()};
    }
};

at::Tensor custom_cross_entropy_loss(const at::Tensor& self, const at::Tensor& target, const c10::optional<at::Tensor>& weight_opt, int64_t reduction, c10::SymInt ignore_index, double label_smoothing) {
    c10::optional<at::Tensor> weight = c10::nullopt;
    if (weight_opt.has_value() && weight_opt.value().defined()) {
        weight = weight_opt;
    }

    return CrossEntropyLossFunction::apply(self, target, weight, reduction, ignore_index, label_smoothing);
}