I’ve encountered a very weird problem: `loss.backward()` for my model works on every machine I have access to (V100 and 2080 Ti), but it fails on the 4090 with the following error:
```
In [2]: print(loss)
tensor(0.3055, device='cuda:1', grad_fn=<MeanBackward0>)

In [3]: loss.backward()
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In[3], line 1
----> 1 loss.backward()
File ~/bin/anaconda3/envs/torch2.0/lib/python3.10/site-packages/torch/_tensor.py:487, in Tensor.backward(self, gradient, retain_graph, create_graph, inputs)
477 if has_torch_function_unary(self):
478 return handle_torch_function(
479 Tensor.backward,
480 (self,),
(...)
485 inputs=inputs,
486 )
--> 487 torch.autograd.backward(
488 self, gradient, retain_graph, create_graph, inputs=inputs
489 )
File ~/bin/anaconda3/envs/torch2.0/lib/python3.10/site-packages/torch/autograd/__init__.py:200, in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)
195 retain_graph = create_graph
197 # The reason we repeat same the comment below is that
198 # some Python versions print out the first line of a multi-line function
199 # calls in the traceback and some print out the last line
--> 200 Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
201 tensors, grad_tensors_, retain_graph, create_graph, inputs,
202 allow_unreachable=True, accumulate_grad=True)
RuntimeError: CUDA error: invalid argument
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
```
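Because CUDA kernels are launched asynchronously, an `invalid argument` error can be reported by a later call than the kernel that actually failed, so the traceback above only says that something in the backward pass broke. A minimal sketch for narrowing it down, assuming the training step can be wrapped in a function: force synchronous launches with the `CUDA_LAUNCH_BLOCKING` environment variable and run the step under `torch.autograd.detect_anomaly()`, which on a backward error should also print the traceback of the forward op that created the failing node. `debug_step`, `model`, `batch`, `target` and `loss_fn` are hypothetical placeholders, not names from my code.

```python
import os

# Must be set before the first CUDA call (safest: before importing torch) so
# that every kernel launch is synchronous and the error is raised by the op
# that actually failed.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"


def debug_step(model, batch, target, loss_fn):
    """Run one forward/backward pass under autograd anomaly detection.

    `model`, `batch`, `target` and `loss_fn` are hypothetical placeholders
    for the objects used in the real training loop.
    """
    import torch  # imported only after CUDA_LAUNCH_BLOCKING has been set

    # Anomaly mode has to cover the forward pass as well: it records where
    # each autograd node was created, so an error raised during backward can
    # be traced back to the forward op (e.g. a specific Conformer layer).
    with torch.autograd.detect_anomaly():
        output = model(batch)
        loss = loss_fn(output, target)
        loss.backward()
    return loss.detach()
```

Calling `debug_step` once on a batch that triggers the error should show which backward node fails. Anomaly mode slows execution considerably, so it is only meant for debugging runs.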
On all machines the driver and CUDA were installed from the same cuda_12.1.1_530.30.02_linux.run file, and all of them use torch 2.0 (torch 1.x also works on the V100 and 2080 Ti).
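One detail that may matter here: prebuilt PyTorch binaries ship their own CUDA runtime, so for them the system installer mainly contributes the driver (530.30.02), while `torch.version.cuda` reports the CUDA version PyTorch itself was built with (torch 2.0 binaries were built against CUDA 11.7 or 11.8). The 4090 is also a newer architecture (compute capability 8.9) than the V100 (7.0) and the 2080 Ti (7.5). The sketch below prints these details per GPU and checks whether the build's compiled architectures cover the device. `parse_arch`, `build_supports_device` and `report_cuda_env` are hypothetical helpers; the compatibility rule they encode is the standard CUDA one as I understand it: SASS for `sm_XY` runs on devices with the same major version and a minor version of at least Y, and PTX for `compute_XY` can be JIT-compiled for any equal or newer device.

```python
import re
from typing import Iterable, Tuple


def parse_arch(entry: str) -> Tuple[str, int, int]:
    """Split an arch string such as 'sm_86' or 'compute_90' into (kind, major, minor).

    Any non-digit suffix (e.g. the 'a' in 'sm_90a') is ignored in this sketch.
    """
    kind, rest = entry.split("_", 1)
    digits = re.match(r"\d+", rest).group(0)
    return kind, int(digits[:-1]), int(digits[-1])


def build_supports_device(arch_list: Iterable[str], capability: Tuple[int, int]) -> bool:
    """Return True if code compiled for `arch_list` can run on a device of `capability`."""
    major, minor = capability
    for entry in arch_list:
        kind, a_major, a_minor = parse_arch(entry)
        # SASS (cubin) is compatible within one major version, upward in minor.
        if kind == "sm" and a_major == major and a_minor <= minor:
            return True
        # PTX can be JIT-compiled by the driver for any equal or newer device.
        if kind == "compute" and (a_major, a_minor) <= (major, minor):
            return True
    return False


def report_cuda_env() -> None:
    """Print the versions and architectures relevant to this issue for every visible GPU."""
    import torch  # imported here so the helpers above work without PyTorch installed

    print("torch:", torch.__version__)
    print("CUDA used to build torch:", torch.version.cuda)
    print("cuDNN:", torch.backends.cudnn.version())
    arch_list = torch.cuda.get_arch_list()
    print("compiled architectures:", arch_list)
    for i in range(torch.cuda.device_count()):
        cap = torch.cuda.get_device_capability(i)
        print(f"cuda:{i}", torch.cuda.get_device_name(i), cap,
              "supported by build:", build_supports_device(arch_list, cap))
```

Running `report_cuda_env()` on each machine and comparing the output would show whether the 4090 box differs in PyTorch's CUDA build, cuDNN version or architecture coverage; `python -m torch.utils.collect_env` prints a more complete environment report.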
Another odd thing is that on my 4090 only this type of model seems to fail in the backward pass. My model is a Conformer-based U-Net. Other models, including a Transformer-based U-Net, appear to backpropagate normally, although I haven’t tested many architectures extensively.
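Since only the Conformer-based U-Net fails, one way to narrow it down is to replay each layer on its own. A rough sketch, assuming each leaf layer takes a single tensor input (layers such as `nn.MultiheadAttention` that take several inputs are skipped): capture every leaf module's input with forward hooks during one forward pass, then run forward and backward for each layer in isolation and collect the ones that raise. `find_failing_modules` is a hypothetical helper, not a PyTorch API.

```python
from typing import Dict, List


def find_failing_modules(model, example_input) -> List[str]:
    """Replay forward+backward for each single-input leaf module and report errors.

    Note: the capturing forward pass runs the model normally, so in train mode
    it also updates BatchNorm running statistics.
    """
    import torch  # imported here so this file can be loaded without PyTorch

    captured: Dict[str, "torch.Tensor"] = {}
    handles = []
    for name, module in model.named_modules():
        if next(module.children(), None) is not None:
            continue  # only hook leaf modules (Conv1d, Linear, LayerNorm, ...)

        def save_input(mod, inputs, output, name=name):
            if len(inputs) == 1 and torch.is_tensor(inputs[0]):
                captured[name] = inputs[0].detach()

        handles.append(module.register_forward_hook(save_input))

    try:
        with torch.no_grad():
            model(example_input)
    finally:
        for handle in handles:
            handle.remove()

    modules = dict(model.named_modules())
    failures = []
    for name, saved in captured.items():
        x = saved.clone()
        if x.is_floating_point():
            x.requires_grad_(True)
        try:
            out = modules[name](x)
            if torch.is_tensor(out) and out.requires_grad:
                out.float().sum().backward()
            if x.is_cuda:
                torch.cuda.synchronize()  # surface asynchronous kernel errors here
        except RuntimeError as exc:
            # Some CUDA errors leave the context unusable, in which case every
            # later replay fails too; the first entry is then the relevant one.
            failures.append(f"{name} ({type(modules[name]).__name__}): {exc}")
    model.zero_grad()  # discard gradients accumulated by the replays
    return failures
```

Calling it with a real input batch on the 4090 (e.g. `find_failing_modules(model, batch)`) would list the layers whose backward raises, which could then be compared against the Transformer-based U-Net's layers.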
My guess is that the issue is caused by the GPU or CUDA, but I haven’t been able to find any useful information. Does anyone have a clue? Thanks!