I am adding a new backend to PyTorch, following "Extending dispatcher for a new backend in C++ — PyTorch Tutorials 2.0.0+cu117 documentation". The new backend uses the reserved dispatch keys PrivateUse1/AutogradPrivateUse1. I have already registered several operators, such as add.Tensor and add.out, but I ran into problems when registering the cross_entropy_loss operator.
The operator registration code is as follows
TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) {
m.impl("cross_entropy_loss", custom_cross_entropy_loss);
}
The test code named test_cross_entropy_loss.py is as follows
import torch
import torch.nn.functional as F
# new backend python module, like torch_xla
import torch_dipu
input = torch.randn(3, 5)
# class-probability targets (supported by F.cross_entropy since PyTorch 1.10)
target = torch.randn(3, 5).softmax(dim=1)
device = torch.device("privateuseone")
input = input.to(device)
input.requires_grad_(True)
target = target.to(device)
loss = F.cross_entropy(input, target)
print(f"loss = {loss}")
loss.backward()
print(f"input.grad = {input.grad}")
When I ran test_cross_entropy_loss.py, I got the following error
Traceback (most recent call last):
File "test_cross_entropy_loss.py", line 12, in <module>
loss = F.cross_entropy(input, target)
File "pytorch/torch/nn/functional.py", line 3029, in cross_entropy
return torch._C._nn.cross_entropy_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index, label_smoothing)
NotImplementedError: Could not run 'aten::cross_entropy_loss' with arguments from the 'AutogradPrivateUse1' backend. This could be because the operator doesn't exist for this backend, or was omitted during the selective/custom build process (if using custom build). If you are a Facebook employee using PyTorch on mobile, please visit https://fburl.com/ptmfixes for possible resolutions. 'aten::cross_entropy_loss' is only available for these backends: [CPU, CUDA, HIP, XLA, MPS, IPU, XPU, HPU, VE, Lazy, Meta, MTIA, PrivateUse1, PrivateUse2, PrivateUse3, FPGA, ORT, Vulkan, Metal, QuantizedCPU, QuantizedCUDA, QuantizedHIP, QuantizedXLA, QuantizedMPS, QuantizedIPU, QuantizedXPU, QuantizedHPU, QuantizedVE, QuantizedLazy, QuantizedMeta, QuantizedMTIA, QuantizedPrivateUse1, QuantizedPrivateUse2, QuantizedPrivateUse3, CustomRNGKeyId, MkldnnCPU, SparseCPU, SparseCUDA, SparseHIP, SparseXLA, SparseMPS, SparseIPU, SparseXPU, SparseHPU, SparseVE, SparseLazy, SparseMeta, SparseMTIA, SparsePrivateUse1, SparsePrivateUse2, SparsePrivateUse3, SparseCsrCPU, SparseCsrCUDA, NestedTensorCPU, NestedTensorCUDA, NestedTensorHIP, NestedTensorXLA, NestedTensorMPS, NestedTensorIPU, NestedTensorXPU, NestedTensorHPU, NestedTensorVE, NestedTensorLazy, NestedTensorMeta, NestedTensorMTIA, NestedTensorPrivateUse1, NestedTensorPrivateUse2, NestedTensorPrivateUse3, BackendSelect, Python, FuncTorchDynamicLayerBackMode, Functionalize, Named, Conjugate, Negative, ZeroTensor, ADInplaceOrView, AutogradOther, AutogradCPU, AutogradCUDA, AutogradHIP, AutogradXLA, AutogradMPS, AutogradIPU, AutogradXPU, AutogradHPU, AutogradVE, AutogradLazy, AutogradMeta, AutogradMTIA, AutogradPrivateUse2, AutogradPrivateUse3, AutogradNestedTensor, Tracer, AutocastCPU, AutocastCUDA, FuncTorchBatched, FuncTorchVmapMode, Batched, VmapMode, FuncTorchGradWrapper, PythonTLSSnapshot, FuncTorchDynamicLayerFrontMode, PythonDispatcher].
So I changed the registration code to
TORCH_LIBRARY_IMPL(aten, AutogradPrivateUse1, m) {
m.impl("cross_entropy_loss", custom_cross_entropy_loss);
}
When I re-ran the code, the forward pass succeeded, but the backward pass reported the following error
Traceback (most recent call last):
File "test_cross_entropy_loss.py", line 15, in <module>
loss.backward()
File "pytorch/torch/_tensor.py", line 488, in backward
torch.autograd.backward(
File "pytorch/torch/autograd/__init__.py", line 199, in backward
Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
I have two questions:
- Why does registering cross_entropy_loss to PrivateUse1 fail?
- When cross_entropy_loss is registered to AutogradPrivateUse1, why does backward fail, and what should I do to implement a custom cross_entropy_loss backward operation (see my attempted sketch below)?
I guess it may be related to cross_entropy_loss being registered as a CompositeImplicitAutograd op, but I haven’t figured out the exact connection.
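For the second question, here is a minimal sketch of the direction I am currently trying, modeled on the MyAddFunction example from the tutorial: wrap my kernels in a torch::autograd::Function so that the AutogradPrivateUse1 kernel attaches a grad_fn to its output. The names custom_cross_entropy_loss_forward and custom_cross_entropy_loss_backward are placeholders for my backend's real kernels, and I am not sure this is the right approach for a CompositeImplicitAutograd op like cross_entropy_loss
#include <torch/torch.h>
#include <torch/library.h>

using torch::Tensor;
using torch::autograd::AutogradContext;
using torch::autograd::tensor_list;

// Placeholder declarations for my backend's real kernels (hypothetical names).
Tensor custom_cross_entropy_loss_forward(const Tensor& input, const Tensor& target);
Tensor custom_cross_entropy_loss_backward(const Tensor& grad_output, const Tensor& input, const Tensor& target);

class CustomCrossEntropyLossFunction
    : public torch::autograd::Function<CustomCrossEntropyLossFunction> {
 public:
  static Tensor forward(AutogradContext* ctx, const Tensor& input, const Tensor& target) {
    ctx->save_for_backward({input, target});
    // Drop below the autograd keys so the forward call does not re-enter this function.
    at::AutoDispatchBelowADInplaceOrView guard;
    return custom_cross_entropy_loss_forward(input, target);
  }

  static tensor_list backward(AutogradContext* ctx, tensor_list grad_outputs) {
    auto saved = ctx->get_saved_variables();
    auto grad_input = custom_cross_entropy_loss_backward(grad_outputs[0], saved[0], saved[1]);
    // One gradient per forward input; the target gets no gradient.
    return {grad_input, Tensor()};
  }
};

// The registered kernel must match the full aten::cross_entropy_loss schema;
// weight/reduction/ignore_index/label_smoothing are accepted but ignored in this sketch.
Tensor custom_cross_entropy_loss_autograd(
    const Tensor& self,
    const Tensor& target,
    const c10::optional<Tensor>& /*weight*/,
    int64_t /*reduction*/,
    c10::SymInt /*ignore_index*/,
    double /*label_smoothing*/) {
  return CustomCrossEntropyLossFunction::apply(self, target);
}

TORCH_LIBRARY_IMPL(aten, AutogradPrivateUse1, m) {
  m.impl("cross_entropy_loss", custom_cross_entropy_loss_autograd);
}
With this I would expect loss to get a grad_fn so that loss.backward() reaches custom_cross_entropy_loss_backward, but I don't know whether bypassing the built-in CompositeImplicitAutograd decomposition this way is the recommended fix.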