RuntimeError: derivative for aten::grid_sampler_3d_backward is not implemented

Hi, I was trying threestudio with a 3D volume grid and ran into this issue. I found a similar report for grid_sampler_2d_backward (RuntimeError: derivative for grid_sampler_2d_backward is not implemented · Issue #34704 · pytorch/pytorch · GitHub), but it has no solution yet. I wonder if anyone has any ideas?

PyTorch version: 2.1.0+cu121

The original error:

File "/mnt/localssd/miniconda3/envs/threestudio/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 543, in fit                                      
    call._call_and_handle_interrupt(                                                                                                                                       
  File "/mnt/localssd/miniconda3/envs/threestudio/lib/python3.9/site-packages/pytorch_lightning/trainer/call.py", line 44, in _call_and_handle_interrupt                   
    return trainer_fn(*args, **kwargs)                                                                                                                                     
  File "/mnt/localssd/miniconda3/envs/threestudio/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 579, in _fit_impl                                
    self._run(model, ckpt_path=ckpt_path)                                                                                                                                  
  File "/mnt/localssd/miniconda3/envs/threestudio/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 986, in _run                                     
    results = self._run_stage()                                                                                                                                            
  File "/mnt/localssd/miniconda3/envs/threestudio/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1030, in _run_stage                              
    self.fit_loop.run()                   
  File "/mnt/localssd/miniconda3/envs/threestudio/lib/python3.9/site-packages/pytorch_lightning/loops/fit_loop.py", line 205, in run                                       
    self.advance()                        
  File "/mnt/localssd/miniconda3/envs/threestudio/lib/python3.9/site-packages/pytorch_lightning/loops/fit_loop.py", line 363, in advance                                   
    self.epoch_loop.run(self._data_fetcher)                                          
  File "/mnt/localssd/miniconda3/envs/threestudio/lib/python3.9/site-packages/pytorch_lightning/loops/training_epoch_loop.py", line 140, in run                            
    self.advance(data_fetcher)            
  File "/mnt/localssd/miniconda3/envs/threestudio/lib/python3.9/site-packages/pytorch_lightning/loops/training_epoch_loop.py", line 250, in advance                        
    batch_output = self.automatic_optimization.run(trainer.optimizers[0], batch_idx, kwargs)                                                                               
  File "/mnt/localssd/miniconda3/envs/threestudio/lib/python3.9/site-packages/pytorch_lightning/loops/optimization/automatic.py", line 190, in run                         
    self._optimizer_step(batch_idx, closure)                                         
  File "/mnt/localssd/miniconda3/envs/threestudio/lib/python3.9/site-packages/pytorch_lightning/loops/optimization/automatic.py", line 268, in _optimizer_step
    call._call_lightning_module_hook(                                                
  File "/mnt/localssd/miniconda3/envs/threestudio/lib/python3.9/site-packages/pytorch_lightning/trainer/call.py", line 159, in _call_lightning_module_hook
    output = fn(*args, **kwargs)          
  File "/mnt/localssd/miniconda3/envs/threestudio/lib/python3.9/site-packages/pytorch_lightning/core/module.py", line 1308, in optimizer_step                              
    optimizer.step(closure=optimizer_closure)                                        
  File "/mnt/localssd/miniconda3/envs/threestudio/lib/python3.9/site-packages/pytorch_lightning/core/optimizer.py", line 153, in step                                      
    step_output = self._strategy.optimizer_step(self._optimizer, closure, **kwargs)                                                                                        
  File "/mnt/localssd/miniconda3/envs/threestudio/lib/python3.9/site-packages/pytorch_lightning/strategies/strategy.py", line 238, in optimizer_step                       
    return self.precision_plugin.optimizer_step(optimizer, model=model, closure=closure, **kwargs)                                                                         
  File "/mnt/localssd/miniconda3/envs/threestudio/lib/python3.9/site-packages/pytorch_lightning/plugins/precision/amp.py", line 77, in optimizer_step                      
    closure_result = closure()            
  File "/mnt/localssd/miniconda3/envs/threestudio/lib/python3.9/site-packages/pytorch_lightning/loops/optimization/automatic.py", line 144, in __call__
    self._result = self.closure(*args, **kwargs)                                     
  File "/mnt/localssd/miniconda3/envs/threestudio/lib/python3.9/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context                                   
    return func(*args, **kwargs)          
  File "/mnt/localssd/miniconda3/envs/threestudio/lib/python3.9/site-packages/pytorch_lightning/loops/optimization/automatic.py", line 138, in closure
    self._backward_fn(step_output.closure_loss) 
File "/mnt/localssd/miniconda3/envs/threestudio/lib/python3.9/site-packages/pytorch_lightning/loops/optimization/automatic.py", line 239, in backward_fn
    call._call_strategy_hook(self.trainer, "backward", loss, optimizer)
  File "/mnt/localssd/miniconda3/envs/threestudio/lib/python3.9/site-packages/pytorch_lightning/trainer/call.py", line 311, in _call_strategy_hook
    output = fn(*args, **kwargs)
  File "/mnt/localssd/miniconda3/envs/threestudio/lib/python3.9/site-packages/pytorch_lightning/strategies/strategy.py", line 212, in backward
    self.precision_plugin.backward(closure_loss, self.lightning_module, optimizer, *args, **kwargs)
  File "/mnt/localssd/miniconda3/envs/threestudio/lib/python3.9/site-packages/pytorch_lightning/plugins/precision/precision.py", line 72, in backward
    model.backward(tensor, *args, **kwargs)
  File "/mnt/localssd/miniconda3/envs/threestudio/lib/python3.9/site-packages/pytorch_lightning/core/module.py", line 1103, in backward
    loss.backward(*args, **kwargs)
  File "/mnt/localssd/miniconda3/envs/threestudio/lib/python3.9/site-packages/torch/_tensor.py", line 492, in backward
    torch.autograd.backward(
  File "/mnt/localssd/miniconda3/envs/threestudio/lib/python3.9/site-packages/torch/autograd/__init__.py", line 251, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: derivative for aten::grid_sampler_3d_backward is not implemented                                     
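
If I read the error right, it is raised when autograd needs to differentiate the backward of the 3D grid sampler, i.e. when a second-order gradient is requested. As a rough guess at the trigger (this is just a sketch I put together, not the actual threestudio code), something along these lines should hit the same missing derivative:

import torch
import torch.nn.functional as F

# 5D input (N, C, D, H, W) and a 3D sampling grid (N, D_out, H_out, W_out, 3)
vol = torch.randn(1, 1, 4, 4, 4, dtype=torch.float64, requires_grad=True)
grid = (torch.rand(1, 2, 2, 2, 3, dtype=torch.float64) * 2 - 1).requires_grad_()

out = F.grid_sample(vol, grid, mode='bilinear', padding_mode='zeros', align_corners=True)

# First-order gradient w.r.t. the grid, kept in the graph (create_graph=True)
# so it can be differentiated again later.
grad_grid, = torch.autograd.grad(out.sum(), grid, create_graph=True)

# Backpropagating through that gradient needs the derivative of
# aten::grid_sampler_3d_backward, which is where the RuntimeError comes from.
grad_grid.sum().backward()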

It seems to work for me using the example from another user:

import torch

# input of shape (N, C, H, W) = (1, 1, 2, 3)
input = torch.tensor([[
    [[1., 2., 3.],
     [4., 5., 6.]]]],
    dtype=torch.float64,
    requires_grad=True)

# grid of shape (N, H_out, W_out, 2) = (1, 2, 3, 2)
grid = torch.tensor([[  # x,y
    [[ 1.,  1.],  # 1
     [-1., -1.],  # 6
     [ 0.,  1.]], # 5

    [[ 0.,  0.],     # between 2 and 5   == (2 + 5) / 2 == 3.5
     [ 0.,  0.5],    # between 3.5 and 5 == (3.5 + 5) / 2 == 4.25
     [-1., -1.]]]],  # etc
     dtype=torch.float64,
     requires_grad=True)

interpolation_mode = 'bilinear'
padding_mode = 'zeros'
align_corners = True

res = torch.nn.functional.grid_sample(input, grid, interpolation_mode, padding_mode, align_corners)
res.mean().backward()
print(input.grad)
# tensor([[[[0.3333, 0.1250, 0.0000],
#           [0.0000, 0.3750, 0.1667]]]], dtype=torch.float64)
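
For what it's worth, the single-backward path also seems fine in the 3D (5D-input) case; the message above complains about the derivative of grid_sampler_3d_backward, i.e. a second-order gradient. A quick 3D sketch of the same check (my own, analogous to the 2D example above):

import torch

# 3D analogue: input (N, C, D, H, W) = (1, 1, 2, 2, 2),
# grid (N, D_out, H_out, W_out, 3) = (1, 1, 1, 1, 3) sampling the volume centre
vol = torch.randn(1, 1, 2, 2, 2, dtype=torch.float64, requires_grad=True)
grid = torch.zeros(1, 1, 1, 1, 3, dtype=torch.float64, requires_grad=True)

res = torch.nn.functional.grid_sample(vol, grid, 'bilinear', 'zeros', align_corners=True)
res.mean().backward()  # plain first-order backward, no error
print(vol.grad)        # with align_corners=True each of the 8 voxels gets weight 1/8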

Could you post a minimal and executable code snippet reproducing the error?

Similar issue with minimal example code attached: