Segfault / SIGSEGV using torchvision.ops.RoIAlign

Howdy,

Applying RoIAlign (torchvision.ops.RoIAlign) with rois given as a single tensor of shape (K, 5) results in a segfault: Process finished with exit code 139 (interrupted by signal 11: SIGSEGV).

The faulty behavior is not observed when:

  • rois is passed as a list of tensors of shape (L, 4), one tensor per batch element
  • the RoIAlign input is small enough (typically no larger than 32,10,10)

Here is a code snippet that should allow you to reproduce my observations.

import argparse

import torch
from torchvision.ops import RoIAlign

NB_ROI_PER_DOC = 5

if __name__ == "__main__":

    torch.manual_seed(0)

    parser = argparse.ArgumentParser(description='minimum reproducible example')
    parser.add_argument('--bug', action='store_true', default=False, help='toggle on the issue')
    parser.add_argument('--batch-size', type=int, default=1)
    parser.add_argument('--input-size', type=str, default="64,10,10")
    args = parser.parse_args()

    align = RoIAlign((3, 3), spatial_scale=14 / 224, sampling_ratio=2)
    batch_size = args.batch_size
    input_ = torch.randn((batch_size, *(int(dim) for dim in args.input_size.split(","))))
    rois = [torch.abs(torch.randn(NB_ROI_PER_DOC, 4)) for _ in input_]

    if args.bug:
        nb_rois = NB_ROI_PER_DOC * batch_size
        boxes = torch.abs(torch.randn(nb_rois, 4))
        roi_ids = torch.arange(nb_rois).view(-1, 1)
        rois = torch.cat((roi_ids, boxes), dim=1)

    output = align(input=input_, rois=rois)
    print(rois)
    print(output.shape)

I run the code in Docker to freeze package versions and the like.

FROM pytorch/pytorch:1.10.0-cuda11.3-cudnn8-runtime
RUN mkdir -p /opt/debug
WORKDIR /opt/debug
ADD mre.py .
ENTRYPOINT ["python", "/opt/debug/mre.py"]

Sample outputs:

Problematic behavior. Running docker build -f Dockerfile.minimal -t debug-torch:v0 . ; docker run debug-torch:v0 --bug ; docker ps -a | grep debug-torch:v0 results in:

46b40fe6efee   debug-torch:v0 "python /opt/debug/m…"     16 seconds ago   Exited (139) 

Expected behavior. Running docker run debug-torch:v0 results in:

[tensor([[1.4164, 0.2379, 0.9334, 1.1331],
        [0.3530, 2.0928, 0.6356, 1.5069],
        [0.9527, 1.0599, 0.9549, 1.3355],
        [0.5251, 0.7416, 0.4269, 0.4008],
        [0.7872, 0.0834, 1.1256, 1.5490]])]
torch.Size([5, 64, 3, 3])

Reducing the input size makes the problem go away. Running docker run debug-torch:v0 --bug --input-size 5,5,5 results in:

tensor([[0.0000, 0.3584, 1.5616, 0.3546, 1.0811],
        [1.0000, 0.8760, 0.2871, 1.0216, 0.5111],
        [2.0000, 1.7137, 0.5101, 0.4749, 0.6334],
        [3.0000, 1.2063, 0.6074, 0.5472, 1.1005],
        [4.0000, 0.7201, 0.0119, 0.3398, 0.2635]])
torch.Size([5, 64, 3, 3])

I’m well aware that the most likely explanation is that I am missing something obvious in how I use RoIAlign. Feel free to simply let me know if that is the case.

Thanks for reading down to this point, and have a nice day.

Hi Francois!

I can’t reproduce your issue running a tweaked version of your code
by pasting it into a python interpreter running with torch / torchvision
versions 1.9.0 / 0.10.0:

Python 3.6.5 |Anaconda, Inc.| (default, Mar 29 2018, 13:32:41) [MSC v.1900 64 bit (AMD64)] on win32
Type "help", "copyright", "credits" or "license" for more information.
>>> import argparse
>>>
>>> import torch
>>> from torchvision.ops import RoIAlign
>>>
>>> NB_ROI_PER_DOC = 5
>>>
>>> if __name__ == "__main__":
...
...     torch.manual_seed(0)
...
...     parser = argparse.ArgumentParser(description='minimum reproducible example')
...     parser.add_argument('--bug', action='store_true', default=False, help='toggle on the issue')
...     parser.add_argument('--batch-size', type=int, default=1)
...     parser.add_argument('--input-size', type=str, default="64,10,10")
...     args = parser.parse_args()
...
...     align = RoIAlign((3, 3), spatial_scale=14 / 224, sampling_ratio=2)
...     batch_size = args.batch_size
...     input_ = torch.randn((batch_size, *(int(dim) for dim in args.input_size.split(","))))
...     rois = [torch.abs(torch.randn(NB_ROI_PER_DOC, 4)) for _ in input_]
...
...     args.bug = True
...     if args.bug:
...         print ('args.bug =', args.bug)
...         nb_rois = NB_ROI_PER_DOC * batch_size
...         boxes = torch.abs(torch.randn(nb_rois, 4))
...         roi_ids = torch.arange(nb_rois).view(-1, 1)
...         rois = torch.cat((roi_ids, boxes), dim=1)
...
...     output = align(input=input_, rois=rois)
...     print(rois)
...     print(output.shape)
...
<torch._C.Generator object at 0x000002C110C51978>
_StoreTrueAction(option_strings=['--bug'], dest='bug', nargs=0, const=True, default=False, type=None, choices=None, help='toggle on the issue', metavar=None)
_StoreAction(option_strings=['--batch-size'], dest='batch_size', nargs=None, const=None, default=1, type=<class 'int'>, choices=None, help=None, metavar=None)
_StoreAction(option_strings=['--input-size'], dest='input_size', nargs=None, const=None, default='64,10,10', type=<class 'str'>, choices=None, help=None, metavar=None)
args.bug = True
tensor([[0.0000, 0.7319, 1.0969, 0.6430, 0.8533],
        [1.0000, 1.2892, 0.2074, 0.4249, 0.0755],
        [2.0000, 1.9148, 2.2872, 0.1998, 0.9350],
        [3.0000, 0.5617, 0.9486, 1.7332, 0.7013],
        [4.0000, 0.1312, 3.0198, 1.7909, 0.4429]])
torch.Size([5, 64, 3, 3])
>>> print (torch.__version__)
1.9.0
>>> import torchvision
>>> print (torchvision.__version__)
0.10.0

Are you expecting this to be some kind of docker / configuration /
hardware issue, or do you think it’s a pure pytorch issue? I don’t see
it with my test. What versions of torch / torchvision are you using?

Can you reproduce your issue with a “raw” python / pytorch installation,
or does it only happen when you use docker?

Best.

K. Frank

Hey Frank, thank you for your attempt to reproduce the issue!

Yes, I can reproduce the issue in a virtual env (i.e. without docker). Using https://raw.githubusercontent.com/pytorch/pytorch/master/torch/utils/collect_env.py I get the environment specs below.

python collect_env.py
Collecting environment information...
PyTorch version: 1.10.0+cu102
Is debug build: False
CUDA used to build PyTorch: 10.2
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.3 LTS (x86_64)
GCC version: (Ubuntu 9.3.0-17ubuntu1~20.04) 9.3.0
Clang version: Could not collect
CMake version: version 3.16.3
Libc version: glibc-2.31

Python version: 3.8.10 (default, Sep 28 2021, 16:10:42)  [GCC 9.3.0] (64-bit runtime)
Python platform: Linux-5.11.0-38-generic-x86_64-with-glibc2.29
Is CUDA available: False
CUDA runtime version: No CUDA
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A

Versions of relevant libraries:
[pip3] mypy-extensions==0.4.3
[pip3] numpy==1.21.3
[pip3] torch==1.10.0
[pip3] torchvision==0.11.1
[conda] Could not collect

Hi Francois!

I was also unable to reproduce your issue using a freshly-downloaded
stable version, 1.10.0 / 0.11.1, under conda on ubuntu:

args.bug = True
tensor([[0.0000, 0.7319, 1.0969, 0.6430, 0.8533],
        [1.0000, 1.2892, 0.2074, 0.4249, 0.0755],
        [2.0000, 1.9148, 2.2872, 0.1998, 0.9350],
        [3.0000, 0.5617, 0.9486, 1.7332, 0.7013],
        [4.0000, 0.1312, 3.0198, 1.7909, 0.4429]])
torch.Size([5, 64, 3, 3])
>>> print (torch.__version__)
1.10.0
>>> import torchvision
>>> print (torchvision.__version__)
0.11.1

(Again, I pasted my tweaked version of your code into a python
interpreter.)

Also, for what it’s worth, I could not reproduce your issue on an older
1.10.0 nightly version:

>>> print (torch.__version__)
1.10.0.dev20210830
>>> import torchvision
>>> print (torchvision.__version__)
0.11.0.dev20210830

Best.

K. Frank

Thank you Frank. Might this have something to do with the fact that I use Python 3.8.10, whereas you use Python 3.6.5 (Anaconda)?
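
One thing that may be worth checking, independently of the Python version: according to the torchvision documentation for roi_align, when rois is passed as a single (K, 5) tensor, the first column is the index of the batch element each box belongs to, so it has to lie in [0, batch_size). In the snippet at the top of the thread, roi_ids is torch.arange(nb_rois), i.e. 0..4 with the default batch size of 1, so four of the five rows refer to batch elements that do not exist. That could plausibly cause out-of-bounds reads, although nothing in this thread confirms it as the cause. Below is a minimal sketch of building and checking such a tensor; build_rois and check_batch_indices are hypothetical helper names, not part of torchvision.

```python
import torch

NB_ROI_PER_DOC = 5


def build_rois(boxes_per_image):
    """Stack per-image (N_i, 4) box tensors into one (K, 5) tensor.

    Column 0 is the index of the batch element each box belongs to
    (hypothetical helper, not a torchvision function).
    """
    rows = []
    for batch_idx, boxes in enumerate(boxes_per_image):
        # One batch index per box, stored with the same dtype as the boxes.
        idx = torch.full((boxes.shape[0], 1), float(batch_idx), dtype=boxes.dtype)
        rows.append(torch.cat((idx, boxes), dim=1))
    return torch.cat(rows, dim=0)


def check_batch_indices(rois, batch_size):
    """Raise ValueError if column 0 of rois is outside [0, batch_size)."""
    idx = rois[:, 0]
    bad = (idx < 0) | (idx >= batch_size)
    if bad.any():
        raise ValueError(
            f"batch indices {idx[bad].tolist()} out of range for batch_size={batch_size}"
        )


torch.manual_seed(0)
batch_size = 1
# Same kind of random boxes as in the original snippet, one tensor per image.
boxes_per_image = [torch.abs(torch.randn(NB_ROI_PER_DOC, 4)) for _ in range(batch_size)]
rois = build_rois(boxes_per_image)
check_batch_indices(rois, batch_size)  # passes: every index is 0
print(rois.shape)
```

The resulting tensor can then be passed as rois to the RoIAlign module from the original snippet, e.g. align(input=input_, rois=rois). Running check_batch_indices on the rois built in the --bug branch with batch_size=1 would raise, since indices 1 to 4 are out of range.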