Bug report: autograd.grad, AdaptiveAvgPool3d, CUDA

There seems to be a bug in AdaptiveAvgPool3d on CUDA.
Can someone confirm this on a different machine or PyTorch version?

The bug occurs when computing autograd.grad through AdaptiveAvgPool3d; the 1d and 2d variants work fine. My torch.__version__ is '1.8.1+cu102'.
To reproduce, run the following code:

import torch
from torch import Tensor
from torch.nn import AdaptiveAvgPool3d

module: AdaptiveAvgPool3d = AdaptiveAvgPool3d(output_size=2)
inputs: Tensor = torch.rand(size=(1, 1, 2, 2, 4))
inputs.requires_grad = True
output_cpu: Tensor = module(inputs)
mat: Tensor = torch.rand_like(output_cpu)
derivative_cpu_torch: Tensor = torch.autograd.grad(output_cpu, inputs, mat)[0]

module.to(device="cuda")
inputs = inputs.to(device="cuda")
mat = mat.to(device="cuda")
output_cuda: Tensor = module(inputs)
derivative_cuda_torch: Tensor = torch.autograd.grad(output_cuda, inputs, mat)[0]

print("Outputs identical?", torch.allclose(output_cpu, output_cuda.to(device="cpu")))
print(
    "Torch derivatives (cuda+cpu) match?",
    torch.allclose(derivative_cpu_torch, derivative_cuda_torch.to(device="cpu")),
)
print("cpu derivative torch:", derivative_cpu_torch)
print("cuda derivative torch:", derivative_cuda_torch)

This produces the following output for me (no fixed seed):

Outputs identical? True
Torch derivatives (cuda+cpu) match? False
cpu derivative torch: tensor([[[[[0.3119, 0.3119, 0.3913, 0.3913],
           [0.3091, 0.3091, 0.0067, 0.0067]],

          [[0.1353, 0.1353, 0.1326, 0.1326],
           [0.2254, 0.2254, 0.4297, 0.4297]]]]])
cuda derivative torch: tensor([[[[[0.3119, 0.3119, 0.3913, 0.3913],
           [0.1353, 0.1353, 0.1326, 0.1326]],

          [[0.1353, 0.1353, 0.1326, 0.1326],
           [0.0000, 0.0000, 0.0000, 0.0000]]]]], device='cuda:0')
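In the output above, the second row of the first depth slice of the CUDA gradient repeats the first row of the second slice, and the last row is all zeros, which suggests the CUDA backward is writing to the wrong elements. Because the input sizes here are exact multiples of the output size, the expected gradient can also be written down directly: each input element receives the upstream gradient of its window divided by the window size. The sketch below compares autograd against that closed form on each available device. `reference_grad` and `check` are hypothetical helpers, and the fixed seed is an addition not present in the original script.

```python
import torch
from torch.nn import AdaptiveAvgPool3d


def reference_grad(grad_out, input_shape, output_size):
    # Hypothetical helper; only valid when every input dim is an exact
    # multiple of the output dim, so the pooling windows do not overlap.
    *_, D, H, W = input_shape
    oD, oH, oW = output_size
    assert D % oD == 0 and H % oH == 0 and W % oW == 0
    kD, kH, kW = D // oD, H // oH, W // oW
    # Spread each output gradient over its window, then divide by the window size
    g = grad_out.repeat_interleave(kD, dim=-3)
    g = g.repeat_interleave(kH, dim=-2)
    g = g.repeat_interleave(kW, dim=-1)
    return g / (kD * kH * kW)


def check(device):
    torch.manual_seed(0)  # fixed seed (assumption; the original script had none)
    module = AdaptiveAvgPool3d(output_size=2)
    inputs = torch.rand(1, 1, 2, 2, 4, device=device, requires_grad=True)
    output = module(inputs)
    mat = torch.rand_like(output)
    grad = torch.autograd.grad(output, inputs, mat)[0]
    expected = reference_grad(mat, inputs.shape, (2, 2, 2))
    return torch.allclose(grad, expected)


devices = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])
for device in devices:
    print(device, "matches reference:", check(device))
```

On an affected setup, one would expect the CPU line to report True and the CUDA line to report False.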

Could you update PyTorch to the latest release (1.9.0) or a nightly build and rerun the script? If you are still seeing the issue, could you post the output of python -m torch.utils.collect_env, please?
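When retesting on a new version, a compact alternative that needs no CPU reference is torch.autograd.gradcheck, which compares the analytical backward against finite differences. The sketch below assumes double-precision inputs (recommended for gradcheck) and uses the functional F.adaptive_avg_pool3d instead of the module; `gradcheck_on` is a hypothetical helper name.

```python
import torch
import torch.nn.functional as F
from torch.autograd import gradcheck


def pool(x):
    return F.adaptive_avg_pool3d(x, output_size=2)


def gradcheck_on(device):
    # Double precision keeps the finite-difference error well below gradcheck's tolerances
    x = torch.rand(1, 1, 2, 2, 4, dtype=torch.double, device=device, requires_grad=True)
    # raise_exception=False makes gradcheck return False instead of raising on a mismatch
    return gradcheck(pool, (x,), raise_exception=False)


print("cpu gradcheck:", gradcheck_on("cpu"))
if torch.cuda.is_available():
    print("cuda gradcheck:", gradcheck_on("cuda"))
```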

I’ve upgraded PyTorch, but sadly the bug persists.
Here is the output of python -m torch.utils.collect_env:

PyTorch version: 1.9.0+cu102
Is debug build: False
CUDA used to build PyTorch: 10.2
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.2 LTS (x86_64)
GCC version: (Ubuntu 9.3.0-17ubuntu1~20.04) 9.3.0
Clang version: 10.0.0-4ubuntu1 
CMake version: Could not collect
Libc version: glibc-2.10

Python version: 3.7 (64-bit runtime)
Python platform: Linux-5.4.0-73-generic-x86_64-with-debian-bullseye-sid
Is CUDA available: True
CUDA runtime version: Could not collect
GPU models and configuration: GPU 0: GeForce RTX 2080 Ti
Nvidia driver version: 460.73.01
cuDNN version: /usr/lib/x86_64-linux-gnu/libcudnn.so.7.5.0
HIP runtime version: N/A
MIOpen runtime version: N/A

Versions of relevant libraries:
[pip] backpack-for-pytorch==1.2.0
[pip] numpy==1.20.2
[pip] pytorch-memlab==0.2.3
[pip] torch==1.9.0
[pip] torchvision==0.10.0
[conda] backpack-for-pytorch      1.2.0                     dev_0    <develop>
[conda] numpy                     1.20.2                   pypi_0    pypi
[conda] pytorch-memlab            0.2.3                    pypi_0    pypi
[conda] torch                     1.9.0                    pypi_0    pypi
[conda] torchvision               0.10.0                   pypi_0    pypi

Thank you very much for the code snippet as well as the additional test using 1.9.0.
I can reproduce the issue on other devices as well and have created an issue to track and fix it.
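Until a fix is available, one possible workaround applies when every input spatial size is an exact multiple of the output size (as in the repro): the adaptive windows are then non-overlapping and equally sized, so the same result can be obtained with the non-adaptive F.avg_pool3d using kernel_size = stride = input_size // output_size. This is a sketch under that assumption; `adaptive_avg_pool3d_via_avg_pool` is a hypothetical helper, and it has not been verified here that the CUDA backward of avg_pool3d is unaffected, so compare it against a CPU run on your own GPU before relying on it.

```python
import torch
import torch.nn.functional as F


def adaptive_avg_pool3d_via_avg_pool(x, output_size):
    """Hypothetical helper: emulate F.adaptive_avg_pool3d with F.avg_pool3d.

    Only valid when each spatial input dim is an exact multiple of the
    corresponding output dim.
    """
    if isinstance(output_size, int):
        output_size = (output_size,) * 3
    oD, oH, oW = output_size
    D, H, W = x.shape[-3:]
    if D % oD or H % oH or W % oW:
        raise ValueError("input spatial size must be a multiple of output_size")
    kernel = (D // oD, H // oH, W // oW)
    # Non-overlapping windows: stride equals kernel size, no padding
    return F.avg_pool3d(x, kernel_size=kernel, stride=kernel)


# Sanity check on CPU: outputs and gradients should match the adaptive op
x = torch.rand(1, 1, 2, 2, 4, requires_grad=True)
ref_out = F.adaptive_avg_pool3d(x, output_size=2)
alt_out = adaptive_avg_pool3d_via_avg_pool(x, output_size=2)
grad_out = torch.rand_like(ref_out)
ref_grad = torch.autograd.grad(ref_out, x, grad_out)[0]
alt_grad = torch.autograd.grad(alt_out, x, grad_out)[0]
print("outputs match:", torch.allclose(ref_out, alt_out))
print("gradients match:", torch.allclose(ref_grad, alt_grad))
```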