RuntimeError: cuda runtime error (710) :

Hello,
I’m trying to implement neural network pruning strategies. The code was working fine until I converted the pruning part from NumPy to PyTorch to leverage GPU speed.
Here is the function that does the sparsification:

def magnitude_pruning(A, drop=0.2):
  # clamp negative drop percentages to zero
  if drop < 0:
    drop = 0
  shape_a = A.shape
  n_elem = A.numel()
  A = A.view(-1)
  # number of elements to zero out
  n_drop = int(n_elem * drop)
  # indices of the n_drop smallest entries
  drop_idxs = torch.topk(A, n_drop, largest=False, sorted=False)[-1]
  # mask that zeroes the selected entries ("device" is defined globally)
  mask = torch.ones(n_elem, device=device)
  mask[drop_idxs] = 0
  A = A * mask
  A = A.view(shape_a)
  return A, mask.view(shape_a)
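
For reference, a minimal usage sketch (how device is defined here is just an assumption, it is a global in my actual code):

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

x = torch.randn(4, 4, device=device)
pruned, mask = magnitude_pruning(x, drop=0.25)
print(mask.sum().item())  # 12 of the 16 entries are kept, 4 are zeroed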

And here is a snippet where I use the function:

os.environ["CUDA_LAUNCH_BLOCKING"]="1"
for key, p_drop in p.items():
  # walk the dotted parameter name down to the module that owns the weight
  attributepath = key.split(".")
  cur = models["resnet18_srn"]
  for attr in attributepath[:-1]:
    if attr.isdecimal():
      cur = cur[int(attr)]
    else:
      cur = getattr(cur, attr)
  W_ = magnitude_pruning(cur.weight.detach(), p_drop)[0]

I get the following error

RuntimeError                              Traceback (most recent call last)
<ipython-input-13-899364821043> in <module>()
      9     else:
     10       cur = getattr(cur, attr)
---> 11   W_ = magnitude_pruning(cur.weight.detach(),value)[0]

<ipython-input-8-4d745d9774f3> in magnitude_pruning(A, drop)
      7   A = A.view(-1)
      8   n_drop = int(n_elem * drop)
----> 9   drop_idxs = torch.topk(A, n_drop, largest=False, sorted=False)[-1]
     10   mask = torch.ones(n_elem, device=device)
     11   mask[drop_idxs] = 0

RuntimeError: cuda runtime error (710) : device-side assert triggered at /pytorch/aten/src/THC/generic/THCTensorTopK.cu:188

I’m running on Colab, and I used both

os.environ["CUDA_LAUNCH_BLOCKING"]="1"
!export CUDA_LAUNCH_BLOCKING=1

to ensure the CUDA error is reported at the spot where it happens.
The p dictionary (which holds the per-layer drop percentages) is as follows:

{'conv1.weight': 0.5248447873548604,
 'last_linear.0.weight': 0.004661282702045355,
 'layer1.0.conv1.weight': 0.37808877777400185,
 'layer1.0.conv2.weight': 0.06504905686978213,
 'layer1.1.conv1.weight': 0.34116854079977843,
 'layer1.1.conv2.weight': -7.333395402042697e-08,
 'layer2.0.conv1.weight': 0.07128420402834179,
 'layer2.0.conv2.weight': -2.1180268428011573e-08,
 'layer2.0.downsample.0.weight': 0.5666850700441095,
 'layer2.1.conv1.weight': -2.1861888077623348e-08,
 'layer2.1.conv2.weight': 0.9691538018553951,
 'layer3.0.conv1.weight': 0.4329534168219621,
 'layer3.0.conv2.weight': 0.06542427263283024,
 'layer3.0.downsample.0.weight': 0.3568506078785666,
 'layer3.1.conv1.weight': 0.09328176622263307,
 'layer3.1.conv2.weight': 0.6633195976051018,
 'layer4.0.conv1.weight': 0.9914340620049251,
 'layer4.0.conv2.weight': 0.8091775692583867,
 'layer4.0.downsample.0.weight': 0.5552448512140766,
 'layer4.1.conv1.weight': 0.9907790148874497,
 'layer4.1.conv2.weight': 0.8827941427912476}

Note that the error does not occur when using the CPU, and it also does not occur with a different p dictionary, such as this one:

{'conv1.weight': 0.7392659868638237,
 'last_linear.0.weight': 0.10501561799105052,
 'layer1.0.conv1.weight': 0.6349508179361485,
 'layer1.0.conv2.weight': 0.44357375136506516,
 'layer1.1.conv1.weight': 0.6290701837014776,
 'layer1.1.conv2.weight': 0.36933085729608905,
 'layer2.0.conv1.weight': 0.4509790372596437,
 'layer2.0.conv2.weight': 0.42935426060761517,
 'layer2.0.downsample.0.weight': 0.7065802232955889,
 'layer2.1.conv1.weight': 0.3707469940086635,
 'layer2.1.conv2.weight': 0.7227469103776935,
 'layer3.0.conv1.weight': 0.7189300890084274,
 'layer3.0.conv2.weight': 0.46585970510469366,
 'layer3.0.downsample.0.weight': 0.6160370089895277,
 'layer3.1.conv1.weight': 0.5312890654313192,
 'layer3.1.conv2.weight': 0.8089039876828459,
 'layer4.0.conv1.weight': 0.8378576175338739,
 'layer4.0.conv2.weight': 0.7549824163148917,
 'layer4.0.downsample.0.weight': 0.6779744785558353,
 'layer4.1.conv1.weight': 0.9224107770405527,
 'layer4.1.conv2.weight': 0.7217466674953182}

Note that, to ensure a non-negative percentage, magnitude_pruning clamps negative values of drop to zero.

Based on the error message and the line of code (assuming the blocking launch was working in the notebook), I guess you might be using an invalid index in:

drop_idxs = torch.topk(A, n_drop, largest=False, sorted=False)[-1]

Run the code on the CPU to get a better error message or, if that doesn’t help, print A and n_drop and check for invalid indices.
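
Something along these lines might narrow it down (a rough sketch reusing the names from your snippets; cur, p_drop, and torch are assumed to be in scope):

# sketch of a sanity check: run topk on a CPU copy and validate the indices
A_flat = cur.weight.detach().cpu().view(-1)
n_drop = int(A_flat.numel() * max(p_drop, 0))
print("n_elem:", A_flat.numel(), "n_drop:", n_drop)

idxs = torch.topk(A_flat, n_drop, largest=False, sorted=False)[1]
if idxs.numel() > 0:
  # every index should be a valid position in the flattened weight
  assert idxs.min() >= 0 and idxs.max() < A_flat.numel(), "invalid index from topk"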

Thank you for your response!
I tried running the code on the CPU, but the problem doesn’t occur there, so I don’t think the problem is invalid indexing.
Also, when I changed the code (for logging purposes) to


def magnitude_pruning(A, drop=0.2):
  if drop<0:
    drop=0
  shape_a = A.shape
  n_elem = A.numel()
  A = A.view(-1)
  n_drop = int(n_elem * drop)
  print("....n_drop = ",n_drop)
  print("....A size = ", n_elem)
  top_k_res = torch.topk(A, n_drop, largest=False, sorted=False)
  print("...topk result = ",top_k_res)

  drop_idxs = top_k_res[1]
  mask = torch.ones(n_elem, device=device)
  mask[drop_idxs] = 0
  A = A * mask 
  A = A.view(shape_a)
  return A, mask.view(shape_a)

The error changed somewhat, to this:

....n_drop =  3816
....A size =  9408
...topk result =  torch.return_types.topk(
values=tensor([-0.0028, -0.0026, -0.0069,  ..., -0.0078, -0.0095, -0.0018],
       device='cuda:0'),
indices=tensor([   1,    2,    6,  ..., 9401, 9407, 8464], device='cuda:0'))
....n_drop =  13695
....A size =  36864
...topk result =  torch.return_types.topk(
values=tensor([-0.0278, -0.0324, -0.2429,  ..., -0.0029, -0.0083, -0.0025],
       device='cuda:0'),
indices=tensor([    1,     3,     4,  ..., 36851, 36863, 35230], device='cuda:0'))
....n_drop =  0
....A size =  36864
...topk result =  torch.return_types.topk(
values=tensor([], device='cuda:0'),
indices=tensor([], device='cuda:0', dtype=torch.int64))
....n_drop =  5047
....A size =  36864
...topk result =  
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-10-f45d5e9895b6> in <module>()
     10     else:
     11       cur = getattr(cur, attr)
---> 12   W_ = magnitude_pruning(cur.weight.detach(),value)[0]

5 frames
<ipython-input-7-51b1f75557a0> in magnitude_pruning(A, drop)
     10   print("....A size = ", n_elem)
     11   top_k_res = torch.topk(A, n_drop, largest=False, sorted=False)
---> 12   print("...topk result = ",top_k_res)
     13 
     14   drop_idxs = top_k_res[1]

/usr/local/lib/python3.6/dist-packages/torch/tensor.py in __repr__(self)
    177             return handle_torch_function(Tensor.__repr__, relevant_args, self)
    178         # All strings are unicode in Python 3.
--> 179         return torch._tensor_str._str(self)
    180 
    181     def backward(self, gradient=None, retain_graph=None, create_graph=False):

/usr/local/lib/python3.6/dist-packages/torch/_tensor_str.py in _str(self)
    370 def _str(self):
    371     with torch.no_grad():
--> 372         return _str_intern(self)

/usr/local/lib/python3.6/dist-packages/torch/_tensor_str.py in _str_intern(self)
    350                     tensor_str = _tensor_str(self.to_dense(), indent)
    351                 else:
--> 352                     tensor_str = _tensor_str(self, indent)
    353 
    354     if self.layout != torch.strided:

/usr/local/lib/python3.6/dist-packages/torch/_tensor_str.py in _tensor_str(self, indent)
    239         return _tensor_str_with_formatter(self, indent, summarize, real_formatter, imag_formatter)
    240     else:
--> 241         formatter = _Formatter(get_summarized_data(self) if summarize else self)
    242         return _tensor_str_with_formatter(self, indent, summarize, formatter)
    243 

/usr/local/lib/python3.6/dist-packages/torch/_tensor_str.py in __init__(self, tensor)
     87 
     88         else:
---> 89             nonzero_finite_vals = torch.masked_select(tensor_view, torch.isfinite(tensor_view) & tensor_view.ne(0))
     90 
     91             if nonzero_finite_vals.numel() == 0:

RuntimeError: CUDA error: device-side assert triggered

I tried to look into the source of the problem and found that it occurs when the value of drop is very low (~0). So I tried using torch.topk on an arbitrary tensor with k=0: it works the first time, but the next call gives RuntimeError: CUDA error: device-side assert triggered. For now I can fix the issue by just adding an if statement (see the sketch below), but I’m curious why it happens.
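
The guard I have in mind is roughly this (just a sketch; it reuses the names from inside magnitude_pruning above):

# sketch: skip topk entirely when there is nothing to drop
mask = torch.ones(n_elem, device=device)
if n_drop > 0:
  drop_idxs = torch.topk(A, n_drop, largest=False, sorted=False)[-1]
  mask[drop_idxs] = 0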
Thank you!

The last stack trace is most likely pointing to the wrong line of code, and the CUDA_LAUNCH_BLOCKING env var is not working properly inside your notebook.
I would generally recommend exporting the notebook as a Python script and running it via:

CUDA_LAUNCH_BLOCKING=1 python script.py args

to get the proper line of code.
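
If you prefer to set it from Python instead, the usual advice is to set the env var before any CUDA work happens in the process, e.g. at the very top of the script before importing torch (a minimal sketch):

import os
# must be set before the CUDA context is created, i.e. before torch touches the GPU
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

import torch

x = torch.randn(10, device="cuda")
print((x ** 2).sum().item())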

That being said, I don’t quite understand the last statement.
Could you post a code snippet showing when the assert is triggered?

I mean that I found the source of the problem: using topk with k=0 makes the GPU “unusable” for the next operation.
For example:

x = torch.Tensor([1,2,3,4,5,2,1,-1,-4,12,4]).cuda()
torch.topk(x,k=0,largest=False)

and then running any command that uses the GPU (I ran it in a different cell)

x = x**2

will result in an error:

RuntimeError                              Traceback (most recent call last)
<ipython-input-3-d4dbc0d35a05> in <module>()
----> 1 x = x**2

RuntimeError: CUDA error: device-side assert triggered

When I put it in a script and run it with CUDA_LAUNCH_BLOCKING=1, I get the following stack trace (for the same snippet of code as above):

/pytorch/aten/src/THC/THCTensorTopK.cuh:107: gatherTopK: block: [0,0,0], thread: [0,0,0] Assertion `writeIndex < outputSliceSize` failed.
/pytorch/aten/src/THC/THCTensorTopK.cuh:107: gatherTopK: block: [0,0,0], thread: [1,0,0] Assertion `writeIndex < outputSliceSize` failed.
/pytorch/aten/src/THC/THCTensorTopK.cuh:107: gatherTopK: block: [0,0,0], thread: [2,0,0] Assertion `writeIndex < outputSliceSize` failed.
/pytorch/aten/src/THC/THCTensorTopK.cuh:107: gatherTopK: block: [0,0,0], thread: [3,0,0] Assertion `writeIndex < outputSliceSize` failed.
/pytorch/aten/src/THC/THCTensorTopK.cuh:107: gatherTopK: block: [0,0,0], thread: [4,0,0] Assertion `writeIndex < outputSliceSize` failed.
/pytorch/aten/src/THC/THCTensorTopK.cuh:107: gatherTopK: block: [0,0,0], thread: [5,0,0] Assertion `writeIndex < outputSliceSize` failed.
/pytorch/aten/src/THC/THCTensorTopK.cuh:107: gatherTopK: block: [0,0,0], thread: [6,0,0] Assertion `writeIndex < outputSliceSize` failed.
/pytorch/aten/src/THC/THCTensorTopK.cuh:107: gatherTopK: block: [0,0,0], thread: [7,0,0] Assertion `writeIndex < outputSliceSize` failed.
/pytorch/aten/src/THC/THCTensorTopK.cuh:107: gatherTopK: block: [0,0,0], thread: [8,0,0] Assertion `writeIndex < outputSliceSize` failed.
/pytorch/aten/src/THC/THCTensorTopK.cuh:107: gatherTopK: block: [0,0,0], thread: [9,0,0] Assertion `writeIndex < outputSliceSize` failed.
/pytorch/aten/src/THC/THCTensorTopK.cuh:107: gatherTopK: block: [0,0,0], thread: [10,0,0] Assertion `writeIndex < outputSliceSize` failed.
THCudaCheck FAIL file=/pytorch/aten/src/THC/generic/THCTensorTopK.cu line=188 error=710 : device-side assert triggered
Traceback (most recent call last):
  File "script.py", line 3, in <module>
    torch.topk(x,k=0,largest=False)
RuntimeError: cuda runtime error (710) : device-side assert triggered at /pytorch/aten/src/THC/generic/THCTensorTopK.cu:188

Thanks for the follow-up!
This is indeed a bug and should be fixed. I’ll create an issue later to track it.