Functional.max_pool1d() CUDA error on large tensor

Using torch.nn.functional.max_pool1d on a 3D tensor with a large second dimension (>= 10^6 elements) gives a CUDA-related error I could not make sense of. Reducing the second dimension avoids the problem, so it might be related to the available GPU memory. Note that I already tried applying x.contiguous() as suggested in this issue.
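Since max_pool1d pools each channel independently, a possible interim workaround along the lines of "reduce the second dimension" is to split that dimension into smaller chunks, pool each chunk, and concatenate the results. This is only a sketch: chunked_max_pool1d and the chunk size are my own names/choices, and I have not verified it against the exact failing configuration.

import torch
import torch.nn.functional as F

def chunked_max_pool1d(x, kernel_size, chunk_size=100_000):
    # max_pool1d pools each channel independently, so pooling chunks of the
    # channel dimension (dim=1) and concatenating gives the same result as
    # pooling the full tensor in one call.
    pooled = [F.max_pool1d(chunk, kernel_size=kernel_size)
              for chunk in x.split(chunk_size, dim=1)]
    return torch.cat(pooled, dim=1)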

With the minimal example below I can reproduce the problem with the following settings:

  • PyTorch nightly version 1.6.0.dev20200514
  • CUDA 10.1
  • Python 3.7.7
  • GPU: GeForce RTX 2080 Ti

Minimal example and traceback:

import torch
import torch.nn.functional as F

size = 10 ** 6
print(f"Size of second dimension: {size}")
poolsize = 4

# Leaf tensor created on the CPU with requires_grad=True, then copied to the
# GPU; the error in the traceback below is reported from the backward of this copy.
x = torch.rand((1, size, poolsize), requires_grad=True).to('cuda:0')
y = F.max_pool1d(x.contiguous(), kernel_size=poolsize, stride=None, padding=0)

loss = torch.sum(y)
loss.backward()

"""
Traceback (most recent call last):
  File "efficient-segmentation/minimal_example.py", line 14, in <module>
    loss.backward()
  File "/home/nfs/username/miniconda3/envs/pytorch-cuda10/lib/python3.7/site-packages/torch/tensor.py", line 184, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/home/nfs/username/miniconda3/envs/pytorch-cuda10/lib/python3.7/site-packages/torch/autograd/__init__.py", line 123, in backward
    allow_unreachable=True)  # allow_unreachable flag
RuntimeError: CUDA error: an illegal memory access was encountered
Exception raised from copy_kernel_cuda at /opt/conda/conda-bld/pytorch_1589440011545/work/aten/src/ATen/native/cuda/Copy.cu:200 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x4d (0x7f79c52dc48d in /home/nfs/username/miniconda3/envs/pytorch-cuda10/lib/python3.7/site-packages/torch/lib/libc10.so)
frame #1: <unknown function> + 0x254fe57 (0x7f79c7a5ae57 in /home/nfs/username/miniconda3/envs/pytorch-cuda10/lib/python3.7/site-packages/torch/lib/libtorch_cuda.so)
frame #2: <unknown function> + 0x92f489 (0x7f79f246e489 in /home/nfs/username/miniconda3/envs/pytorch-cuda10/lib/python3.7/site-packages/torch/lib/libtorch_cpu.so)
frame #3: <unknown function> + 0x92d7c2 (0x7f79f246c7c2 in /home/nfs/username/miniconda3/envs/pytorch-cuda10/lib/python3.7/site-packages/torch/lib/libtorch_cpu.so)
frame #4: at::native::copy_(at::Tensor&, at::Tensor const&, bool) + 0x44 (0x7f79f246db04 in /home/nfs/username/miniconda3/envs/pytorch-cuda10/lib/python3.7/site-packages/torch/lib/libtorch_cpu.so)
frame #5: <unknown function> + 0x31b2198 (0x7f79f4cf1198 in /home/nfs/username/miniconda3/envs/pytorch-cuda10/lib/python3.7/site-packages/torch/lib/libtorch_cpu.so)
frame #6: <unknown function> + 0xba40a1 (0x7f79f26e30a1 in /home/nfs/username/miniconda3/envs/pytorch-cuda10/lib/python3.7/site-packages/torch/lib/libtorch_cpu.so)
frame #7: at::native::to(at::Tensor const&, c10::TensorOptions const&, bool, bool, c10::optional<c10::MemoryFormat>) + 0x802 (0x7f79f26e3d72 in /home/nfs/username/miniconda3/envs/pytorch-cuda10/lib/python3.7/site-packages/torch/lib/libtorch_cpu.so)
frame #8: <unknown function> + 0xef07ca (0x7f79f2a2f7ca in /home/nfs/username/miniconda3/envs/pytorch-cuda10/lib/python3.7/site-packages/torch/lib/libtorch_cpu.so)
frame #9: <unknown function> + 0x29a95ab (0x7f79f44e85ab in /home/nfs/username/miniconda3/envs/pytorch-cuda10/lib/python3.7/site-packages/torch/lib/libtorch_cpu.so)
frame #10: <unknown function> + 0xe22292 (0x7f79f2961292 in /home/nfs/username/miniconda3/envs/pytorch-cuda10/lib/python3.7/site-packages/torch/lib/libtorch_cpu.so)
frame #11: torch::autograd::CopyBackwards::apply(std::vector<at::Tensor, std::allocator<at::Tensor> >&&) + 0x495 (0x7f79f489df55 in /home/nfs/username/miniconda3/envs/pytorch-cuda10/lib/python3.7/site-packages/torch/lib/libtorch_cpu.so)
frame #12: <unknown function> + 0x2d58847 (0x7f79f4897847 in /home/nfs/username/miniconda3/envs/pytorch-cuda10/lib/python3.7/site-packages/torch/lib/libtorch_cpu.so)
frame #13: torch::autograd::Engine::evaluate_function(std::shared_ptr<torch::autograd::GraphTask>&, torch::autograd::Node*, torch::autograd::InputBuffer&, std::shared_ptr<torch::autograd::ReadyQueue> const&) + 0x1740 (0x7f79f48929d0 in /home/nfs/username/miniconda3/envs/pytorch-cuda10/lib/python3.7/site-packages/torch/lib/libtorch_cpu.so)
frame #14: torch::autograd::Engine::thread_main(std::shared_ptr<torch::autograd::GraphTask> const&, bool) + 0x4ee (0x7f79f489382e in /home/nfs/username/miniconda3/envs/pytorch-cuda10/lib/python3.7/site-packages/torch/lib/libtorch_cpu.so)
frame #15: torch::autograd::Engine::thread_init(int, std::shared_ptr<torch::autograd::ReadyQueue> const&) + 0x94 (0x7f79f4889b54 in /home/nfs/username/miniconda3/envs/pytorch-cuda10/lib/python3.7/site-packages/torch/lib/libtorch_cpu.so)
frame #16: torch::autograd::python::PythonEngine::thread_init(int, std::shared_ptr<torch::autograd::ReadyQueue> const&) + 0x40 (0x7f79f8458820 in /home/nfs/username/miniconda3/envs/pytorch-cuda10/lib/python3.7/site-packages/torch/lib/libtorch_python.so)
frame #17: <unknown function> + 0xc819d (0x7f79fb12019d in /home/nfs/username/miniconda3/envs/pytorch-cuda10/lib/python3.7/site-packages/torch/lib/../../../.././libstdc++.so.6)
frame #18: <unknown function> + 0x7ea5 (0x7f7a19b1fea5 in /lib64/libpthread.so.0)
frame #19: clone + 0x6d (0x7f7a198488dd in /lib64/libc.so.6)
"""

Thanks for the reproducible code snippet!
Your code should be correct; we are most likely causing an illegal memory access internally.
I’ve reproduced the bug and will create an issue so we can fix it.
Once it’s fixed I’ll ping you here.

EDIT: GitHub issue to track and fix it.
Thanks again for reporting and narrowing down this issue, @rjbruin! :slight_smile:
