Batchnorm 1D Cuda error

I am running the SimSiam code as in the official facebook repo: https://github.com/facebookresearch/simsiam

The code runs for 1 or 2 steps and then I get the cuda illegal memory access error. This is exactly the same as the question posed in Batchnorm1D - CUDA error: an illegal memory access was encountered.

I have been running it on different versions of torch on multiple Tesla V100-SXM2-16GB GPUs. More concretely, I have tried the code on

PyTorch version: 1.10.1+cu102
Is debug build: False
CUDA used to build PyTorch: 10.2
ROCM used to build PyTorch: N/A

OS: Ubuntu 18.04.5 LTS (x86_64)
GCC version: (Ubuntu 7.5.0-3ubuntu1~18.04) 7.5.0

and also on

torch ==1.9.0+cu102

The system CUDA version for both PyTorch installs is 10.0.130.
Here is the complete version of the error:

Epoch: [0][   0/1476]   Time 13.362 (13.362)    Data 12.197 (12.197)    Loss -0.0016 (-0.0016)
terminate called after throwing an instance of 'c10::CUDAError'
  what():  CUDA error: an illegal memory access was encountered
Exception raised from create_event_internal at /pytorch/c10/cuda/CUDACachingAllocator.cpp:1055 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x42 (0x7f1dbd189a22 in /home/ubuntu/anaconda3/lib/python3.7/site-packages/torch/lib/libc10.so)
frame #1: <unknown function> + 0x10983 (0x7f1dbd3ea983 in /home/ubuntu/anaconda3/lib/python3.7/site-packages/torch/lib/libc10_cuda.so)
frame #2: c10::cuda::CUDACachingAllocator::raw_delete(void*) + 0x1a7 (0x7f1dbd3ec027 in /home/ubuntu/anaconda3/lib/python3.7/site-packages/torch/lib/libc10_cuda.so)
frame #3: c10::TensorImpl::release_resources() + 0x54 (0x7f1dbd1735a4 in /home/ubuntu/anaconda3/lib/python3.7/site-packages/torch/lib/libc10.so)
frame #4: <unknown function> + 0xa2ef72 (0x7f1e13ce8f72 in /home/ubuntu/anaconda3/lib/python3.7/site-packages/torch/lib/libtorch_python.so)
frame #5: <unknown function> + 0xa2f011 (0x7f1e13ce9011 in /home/ubuntu/anaconda3/lib/python3.7/site-packages/torch/lib/libtorch_python.so)
frame #6: <unknown function> + 0x183dd6 (0x56207558add6 in /home/ubuntu/anaconda3/bin/python)
frame #7: <unknown function> + 0xe5c62 (0x5620754ecc62 in /home/ubuntu/anaconda3/bin/python)
frame #8: <unknown function> + 0xe5928 (0x5620754ec928 in /home/ubuntu/anaconda3/bin/python)
frame #9: <unknown function> + 0xe62c8 (0x5620754ed2c8 in /home/ubuntu/anaconda3/bin/python)
frame #10: <unknown function> + 0xe62de (0x5620754ed2de in /home/ubuntu/anaconda3/bin/python)
frame #11: <unknown function> + 0xe62de (0x5620754ed2de in /home/ubuntu/anaconda3/bin/python)
frame #12: <unknown function> + 0xe62de (0x5620754ed2de in /home/ubuntu/anaconda3/bin/python)
frame #13: <unknown function> + 0xe62de (0x5620754ed2de in /home/ubuntu/anaconda3/bin/python)
frame #14: <unknown function> + 0xe62de (0x5620754ed2de in /home/ubuntu/anaconda3/bin/python)
frame #15: <unknown function> + 0xe62de (0x5620754ed2de in /home/ubuntu/anaconda3/bin/python)
frame #16: <unknown function> + 0xe62de (0x5620754ed2de in /home/ubuntu/anaconda3/bin/python)
frame #17: <unknown function> + 0xe62de (0x5620754ed2de in /home/ubuntu/anaconda3/bin/python)
frame #18: <unknown function> + 0xe62de (0x5620754ed2de in /home/ubuntu/anaconda3/bin/python)
frame #19: <unknown function> + 0xe62de (0x5620754ed2de in /home/ubuntu/anaconda3/bin/python)
frame #20: <unknown function> + 0xe62de (0x5620754ed2de in /home/ubuntu/anaconda3/bin/python)
frame #21: <unknown function> + 0xe6c34 (0x5620754edc34 in /home/ubuntu/anaconda3/bin/python)
frame #22: <unknown function> + 0x1c612b (0x5620755cd12b in /home/ubuntu/anaconda3/bin/python)
frame #23: <unknown function> + 0x1c6253 (0x5620755cd253 in /home/ubuntu/anaconda3/bin/python)
frame #24: _PyEval_EvalFrameDefault + 0x2c82 (0x5620755bf7c2 in /home/ubuntu/anaconda3/bin/python)
frame #25: _PyFunction_FastCallKeywords + 0x184 (0x562075532e94 in /home/ubuntu/anaconda3/bin/python)
frame #26: <unknown function> + 0x18f9c8 (0x5620755969c8 in /home/ubuntu/anaconda3/bin/python)
frame #27: _PyEval_EvalFrameDefault + 0xa04 (0x5620755bd544 in /home/ubuntu/anaconda3/bin/python)
frame #28: _PyFunction_FastCallKeywords + 0x184 (0x562075532e94 in /home/ubuntu/anaconda3/bin/python)
frame #29: <unknown function> + 0x18f9c8 (0x5620755969c8 in /home/ubuntu/anaconda3/bin/python)
frame #30: _PyEval_EvalFrameDefault + 0x96c (0x5620755bd4ac in /home/ubuntu/anaconda3/bin/python)
frame #31: _PyEval_EvalCodeWithName + 0x242 (0x562075504af2 in /home/ubuntu/anaconda3/bin/python)
frame #32: _PyFunction_FastCallKeywords + 0x320 (0x562075533030 in /home/ubuntu/anaconda3/bin/python)
frame #33: <unknown function> + 0x18f9c8 (0x5620755969c8 in /home/ubuntu/anaconda3/bin/python)
frame #34: _PyEval_EvalFrameDefault + 0x1699 (0x5620755be1d9 in /home/ubuntu/anaconda3/bin/python)
frame #35: _PyEval_EvalCodeWithName + 0x242 (0x562075504af2 in /home/ubuntu/anaconda3/bin/python)
frame #36: PyEval_EvalCodeEx + 0x39 (0x562075505d09 in /home/ubuntu/anaconda3/bin/python)
frame #37: PyEval_EvalCode + 0x1b (0x5620755e08ab in /home/ubuntu/anaconda3/bin/python)
frame #38: <unknown function> + 0x23df53 (0x562075644f53 in /home/ubuntu/anaconda3/bin/python)
frame #39: PyRun_StringFlags + 0x7d (0x56207564e8cd in /home/ubuntu/anaconda3/bin/python)
frame #40: PyRun_SimpleStringFlags + 0x3d (0x56207564e92d in /home/ubuntu/anaconda3/bin/python)
frame #41: <unknown function> + 0x248487 (0x56207564f487 in /home/ubuntu/anaconda3/bin/python)
frame #42: _Py_UnixMain + 0x3c (0x56207564f85c in /home/ubuntu/anaconda3/bin/python)
frame #43: __libc_start_main + 0xe7 (0x7f1e25cddbf7 in /lib/x86_64-linux-gnu/libc.so.6)
frame #44: <unknown function> + 0x1c5901 (0x5620755cc901 in /home/ubuntu/anaconda3/bin/python)

terminate called after throwing an instance of 'c10::CUDAError'
  what():  CUDA error: an illegal memory access was encountered
Exception raised from create_event_internal at /pytorch/c10/cuda/CUDACachingAllocator.cpp:1055 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x42 (0x7f324f882a22 in /home/ubuntu/anaconda3/lib/python3.7/site-packages/torch/lib/libc10.so)
frame #1: <unknown function> + 0x10983 (0x7f324fae3983 in /home/ubuntu/anaconda3/lib/python3.7/site-packages/torch/lib/libc10_cuda.so)
frame #2: c10::cuda::CUDACachingAllocator::raw_delete(void*) + 0x1a7 (0x7f324fae5027 in /home/ubuntu/anaconda3/lib/python3.7/site-packages/torch/lib/libc10_cuda.so)
frame #3: c10::TensorImpl::release_resources() + 0x54 (0x7f324f86c5a4 in /home/ubuntu/anaconda3/lib/python3.7/site-packages/torch/lib/libc10.so)
frame #4: <unknown function> + 0xa2ef72 (0x7f32a63e1f72 in /home/ubuntu/anaconda3/lib/python3.7/site-packages/torch/lib/libtorch_python.so)
frame #5: <unknown function> + 0xa2f011 (0x7f32a63e2011 in /home/ubuntu/anaconda3/lib/python3.7/site-packages/torch/lib/libtorch_python.so)
frame #6: <unknown function> + 0x183dd6 (0x557d5a116dd6 in /home/ubuntu/anaconda3/bin/python)
frame #7: <unknown function> + 0xe5c62 (0x557d5a078c62 in /home/ubuntu/anaconda3/bin/python)
frame #8: <unknown function> + 0xe5928 (0x557d5a078928 in /home/ubuntu/anaconda3/bin/python)
frame #9: <unknown function> + 0xe62c8 (0x557d5a0792c8 in /home/ubuntu/anaconda3/bin/python)
frame #10: <unknown function> + 0xe62de (0x557d5a0792de in /home/ubuntu/anaconda3/bin/python)
frame #11: <unknown function> + 0xe62de (0x557d5a0792de in /home/ubuntu/anaconda3/bin/python)
frame #12: <unknown function> + 0xe62de (0x557d5a0792de in /home/ubuntu/anaconda3/bin/python)
frame #13: <unknown function> + 0xe62de (0x557d5a0792de in /home/ubuntu/anaconda3/bin/python)
frame #14: <unknown function> + 0xe62de (0x557d5a0792de in /home/ubuntu/anaconda3/bin/python)
frame #15: <unknown function> + 0xe62de (0x557d5a0792de in /home/ubuntu/anaconda3/bin/python)
frame #16: <unknown function> + 0xe62de (0x557d5a0792de in /home/ubuntu/anaconda3/bin/python)
frame #17: <unknown function> + 0xe62de (0x557d5a0792de in /home/ubuntu/anaconda3/bin/python)
frame #18: <unknown function> + 0xe62de (0x557d5a0792de in /home/ubuntu/anaconda3/bin/python)
frame #19: <unknown function> + 0xe62de (0x557d5a0792de in /home/ubuntu/anaconda3/bin/python)
frame #20: <unknown function> + 0xe62de (0x557d5a0792de in /home/ubuntu/anaconda3/bin/python)
frame #21: <unknown function> + 0xe6c34 (0x557d5a079c34 in /home/ubuntu/anaconda3/bin/python)
frame #22: <unknown function> + 0x1c612b (0x557d5a15912b in /home/ubuntu/anaconda3/bin/python)
frame #23: <unknown function> + 0x1c6253 (0x557d5a159253 in /home/ubuntu/anaconda3/bin/python)
frame #24: _PyEval_EvalFrameDefault + 0x2c82 (0x557d5a14b7c2 in /home/ubuntu/anaconda3/bin/python)
frame #25: _PyFunction_FastCallKeywords + 0x184 (0x557d5a0bee94 in /home/ubuntu/anaconda3/bin/python)
frame #26: <unknown function> + 0x18f9c8 (0x557d5a1229c8 in /home/ubuntu/anaconda3/bin/python)
frame #27: _PyEval_EvalFrameDefault + 0xa04 (0x557d5a149544 in /home/ubuntu/anaconda3/bin/python)
frame #28: _PyFunction_FastCallKeywords + 0x184 (0x557d5a0bee94 in /home/ubuntu/anaconda3/bin/python)
frame #29: <unknown function> + 0x18f9c8 (0x557d5a1229c8 in /home/ubuntu/anaconda3/bin/python)
frame #30: _PyEval_EvalFrameDefault + 0x96c (0x557d5a1494ac in /home/ubuntu/anaconda3/bin/python)
frame #31: _PyEval_EvalCodeWithName + 0x242 (0x557d5a090af2 in /home/ubuntu/anaconda3/bin/python)
frame #32: _PyFunction_FastCallKeywords + 0x320 (0x557d5a0bf030 in /home/ubuntu/anaconda3/bin/python)
frame #33: <unknown function> + 0x18f9c8 (0x557d5a1229c8 in /home/ubuntu/anaconda3/bin/python)
frame #34: _PyEval_EvalFrameDefault + 0x1699 (0x557d5a14a1d9 in /home/ubuntu/anaconda3/bin/python)
frame #35: _PyEval_EvalCodeWithName + 0x242 (0x557d5a090af2 in /home/ubuntu/anaconda3/bin/python)
frame #36: PyEval_EvalCodeEx + 0x39 (0x557d5a091d09 in /home/ubuntu/anaconda3/bin/python)
frame #37: PyEval_EvalCode + 0x1b (0x557d5a16c8ab in /home/ubuntu/anaconda3/bin/python)
frame #38: <unknown function> + 0x23df53 (0x557d5a1d0f53 in /home/ubuntu/anaconda3/bin/python)
frame #39: PyRun_StringFlags + 0x7d (0x557d5a1da8cd in /home/ubuntu/anaconda3/bin/python)
frame #40: PyRun_SimpleStringFlags + 0x3d (0x557d5a1da92d in /home/ubuntu/anaconda3/bin/python)
frame #41: <unknown function> + 0x248487 (0x557d5a1db487 in /home/ubuntu/anaconda3/bin/python)
frame #42: _Py_UnixMain + 0x3c (0x557d5a1db85c in /home/ubuntu/anaconda3/bin/python)
frame #43: __libc_start_main + 0xe7 (0x7f32b83d6bf7 in /lib/x86_64-linux-gnu/libc.so.6)
frame #44: <unknown function> + 0x1c5901 (0x557d5a158901 in /home/ubuntu/anaconda3/bin/python)

terminate called after throwing an instance of 'c10::CUDAError'
  what():  CUDA error: an illegal memory access was encountered
Exception raised from create_event_internal at /pytorch/c10/cuda/CUDACachingAllocator.cpp:1055 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x42 (0x7f910eb15a22 in /home/ubuntu/anaconda3/lib/python3.7/site-packages/torch/lib/libc10.so)
frame #1: <unknown function> + 0x10983 (0x7f910ed76983 in /home/ubuntu/anaconda3/lib/python3.7/site-packages/torch/lib/libc10_cuda.so)
frame #2: c10::cuda::CUDACachingAllocator::raw_delete(void*) + 0x1a7 (0x7f910ed78027 in /home/ubuntu/anaconda3/lib/python3.7/site-packages/torch/lib/libc10_cuda.so)
frame #3: c10::TensorImpl::release_resources() + 0x54 (0x7f910eaff5a4 in /home/ubuntu/anaconda3/lib/python3.7/site-packages/torch/lib/libc10.so)
frame #4: <unknown function> + 0xa2ef72 (0x7f9165674f72 in /home/ubuntu/anaconda3/lib/python3.7/site-packages/torch/lib/libtorch_python.so)
frame #5: <unknown function> + 0xa2f011 (0x7f9165675011 in /home/ubuntu/anaconda3/lib/python3.7/site-packages/torch/lib/libtorch_python.so)
frame #6: <unknown function> + 0x183dd6 (0x5585ccfafdd6 in /home/ubuntu/anaconda3/bin/python)
frame #7: <unknown function> + 0xe5c62 (0x5585ccf11c62 in /home/ubuntu/anaconda3/bin/python)
frame #8: <unknown function> + 0xe5928 (0x5585ccf11928 in /home/ubuntu/anaconda3/bin/python)
frame #9: <unknown function> + 0xe62c8 (0x5585ccf122c8 in /home/ubuntu/anaconda3/bin/python)
frame #10: <unknown function> + 0xe62de (0x5585ccf122de in /home/ubuntu/anaconda3/bin/python)
frame #11: <unknown function> + 0xe62de (0x5585ccf122de in /home/ubuntu/anaconda3/bin/python)
frame #12: <unknown function> + 0xe62de (0x5585ccf122de in /home/ubuntu/anaconda3/bin/python)
frame #13: <unknown function> + 0xe62de (0x5585ccf122de in /home/ubuntu/anaconda3/bin/python)
frame #14: <unknown function> + 0xe62de (0x5585ccf122de in /home/ubuntu/anaconda3/bin/python)
frame #15: <unknown function> + 0xe62de (0x5585ccf122de in /home/ubuntu/anaconda3/bin/python)
frame #16: <unknown function> + 0xe62de (0x5585ccf122de in /home/ubuntu/anaconda3/bin/python)
frame #17: <unknown function> + 0xe62de (0x5585ccf122de in /home/ubuntu/anaconda3/bin/python)
frame #18: <unknown function> + 0xe62de (0x5585ccf122de in /home/ubuntu/anaconda3/bin/python)
frame #19: <unknown function> + 0xe62de (0x5585ccf122de in /home/ubuntu/anaconda3/bin/python)
frame #20: <unknown function> + 0xe62de (0x5585ccf122de in /home/ubuntu/anaconda3/bin/python)
frame #21: <unknown function> + 0xe6c34 (0x5585ccf12c34 in /home/ubuntu/anaconda3/bin/python)
frame #22: <unknown function> + 0x1c612b (0x5585ccff212b in /home/ubuntu/anaconda3/bin/python)
frame #23: <unknown function> + 0x1c6253 (0x5585ccff2253 in /home/ubuntu/anaconda3/bin/python)
frame #24: _PyEval_EvalFrameDefault + 0x2c82 (0x5585ccfe47c2 in /home/ubuntu/anaconda3/bin/python)
frame #25: _PyFunction_FastCallKeywords + 0x184 (0x5585ccf57e94 in /home/ubuntu/anaconda3/bin/python)
frame #26: <unknown function> + 0x18f9c8 (0x5585ccfbb9c8 in /home/ubuntu/anaconda3/bin/python)
frame #27: _PyEval_EvalFrameDefault + 0xa04 (0x5585ccfe2544 in /home/ubuntu/anaconda3/bin/python)
frame #28: _PyFunction_FastCallKeywords + 0x184 (0x5585ccf57e94 in /home/ubuntu/anaconda3/bin/python)
frame #29: <unknown function> + 0x18f9c8 (0x5585ccfbb9c8 in /home/ubuntu/anaconda3/bin/python)
frame #30: _PyEval_EvalFrameDefault + 0x96c (0x5585ccfe24ac in /home/ubuntu/anaconda3/bin/python)
frame #31: _PyEval_EvalCodeWithName + 0x242 (0x5585ccf29af2 in /home/ubuntu/anaconda3/bin/python)
frame #32: _PyFunction_FastCallKeywords + 0x320 (0x5585ccf58030 in /home/ubuntu/anaconda3/bin/python)
frame #33: <unknown function> + 0x18f9c8 (0x5585ccfbb9c8 in /home/ubuntu/anaconda3/bin/python)
frame #34: _PyEval_EvalFrameDefault + 0x1699 (0x5585ccfe31d9 in /home/ubuntu/anaconda3/bin/python)
frame #35: _PyEval_EvalCodeWithName + 0x242 (0x5585ccf29af2 in /home/ubuntu/anaconda3/bin/python)
frame #36: PyEval_EvalCodeEx + 0x39 (0x5585ccf2ad09 in /home/ubuntu/anaconda3/bin/python)
frame #37: PyEval_EvalCode + 0x1b (0x5585cd0058ab in /home/ubuntu/anaconda3/bin/python)
frame #38: <unknown function> + 0x23df53 (0x5585cd069f53 in /home/ubuntu/anaconda3/bin/python)
frame #39: PyRun_StringFlags + 0x7d (0x5585cd0738cd in /home/ubuntu/anaconda3/bin/python)
frame #40: PyRun_SimpleStringFlags + 0x3d (0x5585cd07392d in /home/ubuntu/anaconda3/bin/python)
frame #41: <unknown function> + 0x248487 (0x5585cd074487 in /home/ubuntu/anaconda3/bin/python)
frame #42: _Py_UnixMain + 0x3c (0x5585cd07485c in /home/ubuntu/anaconda3/bin/python)
frame #43: __libc_start_main + 0xe7 (0x7f9177669bf7 in /lib/x86_64-linux-gnu/libc.so.6)
frame #44: <unknown function> + 0x1c5901 (0x5585ccff1901 in /home/ubuntu/anaconda3/bin/python)

terminate called after throwing an instance of 'c10::CUDAError'
  what():  CUDA error: an illegal memory access was encountered
Exception raised from create_event_internal at /pytorch/c10/cuda/CUDACachingAllocator.cpp:1055 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x42 (0x7fad064d3a22 in /home/ubuntu/anaconda3/lib/python3.7/site-packages/torch/lib/libc10.so)
frame #1: <unknown function> + 0x10983 (0x7fad06734983 in /home/ubuntu/anaconda3/lib/python3.7/site-packages/torch/lib/libc10_cuda.so)
frame #2: c10::cuda::CUDACachingAllocator::raw_delete(void*) + 0x1a7 (0x7fad06736027 in /home/ubuntu/anaconda3/lib/python3.7/site-packages/torch/lib/libc10_cuda.so)
frame #3: c10::TensorImpl::release_resources() + 0x54 (0x7fad064bd5a4 in /home/ubuntu/anaconda3/lib/python3.7/site-packages/torch/lib/libc10.so)
frame #4: <unknown function> + 0xa2ef72 (0x7fad5d032f72 in /home/ubuntu/anaconda3/lib/python3.7/site-packages/torch/lib/libtorch_python.so)
frame #5: <unknown function> + 0xa2f011 (0x7fad5d033011 in /home/ubuntu/anaconda3/lib/python3.7/site-packages/torch/lib/libtorch_python.so)
frame #6: <unknown function> + 0x183dd6 (0x5564b989bdd6 in /home/ubuntu/anaconda3/bin/python)
frame #7: <unknown function> + 0xe5c62 (0x5564b97fdc62 in /home/ubuntu/anaconda3/bin/python)
frame #8: <unknown function> + 0xe5928 (0x5564b97fd928 in /home/ubuntu/anaconda3/bin/python)
frame #9: <unknown function> + 0xe62c8 (0x5564b97fe2c8 in /home/ubuntu/anaconda3/bin/python)
frame #10: <unknown function> + 0xe62de (0x5564b97fe2de in /home/ubuntu/anaconda3/bin/python)
frame #11: <unknown function> + 0xe62de (0x5564b97fe2de in /home/ubuntu/anaconda3/bin/python)
frame #12: <unknown function> + 0xe62de (0x5564b97fe2de in /home/ubuntu/anaconda3/bin/python)
frame #13: <unknown function> + 0xe62de (0x5564b97fe2de in /home/ubuntu/anaconda3/bin/python)
frame #14: <unknown function> + 0xe62de (0x5564b97fe2de in /home/ubuntu/anaconda3/bin/python)
frame #15: <unknown function> + 0xe62de (0x5564b97fe2de in /home/ubuntu/anaconda3/bin/python)
frame #16: <unknown function> + 0xe62de (0x5564b97fe2de in /home/ubuntu/anaconda3/bin/python)
frame #17: <unknown function> + 0xe62de (0x5564b97fe2de in /home/ubuntu/anaconda3/bin/python)
frame #18: <unknown function> + 0xe62de (0x5564b97fe2de in /home/ubuntu/anaconda3/bin/python)
frame #19: <unknown function> + 0xe62de (0x5564b97fe2de in /home/ubuntu/anaconda3/bin/python)
frame #20: <unknown function> + 0xe62de (0x5564b97fe2de in /home/ubuntu/anaconda3/bin/python)
frame #21: <unknown function> + 0xe6c34 (0x5564b97fec34 in /home/ubuntu/anaconda3/bin/python)
frame #22: <unknown function> + 0x1c612b (0x5564b98de12b in /home/ubuntu/anaconda3/bin/python)
frame #23: <unknown function> + 0x1c6253 (0x5564b98de253 in /home/ubuntu/anaconda3/bin/python)
frame #24: _PyEval_EvalFrameDefault + 0x2c82 (0x5564b98d07c2 in /home/ubuntu/anaconda3/bin/python)
frame #25: _PyFunction_FastCallKeywords + 0x184 (0x5564b9843e94 in /home/ubuntu/anaconda3/bin/python)
frame #26: <unknown function> + 0x18f9c8 (0x5564b98a79c8 in /home/ubuntu/anaconda3/bin/python)
frame #27: _PyEval_EvalFrameDefault + 0xa04 (0x5564b98ce544 in /home/ubuntu/anaconda3/bin/python)
frame #28: _PyFunction_FastCallKeywords + 0x184 (0x5564b9843e94 in /home/ubuntu/anaconda3/bin/python)
frame #29: <unknown function> + 0x18f9c8 (0x5564b98a79c8 in /home/ubuntu/anaconda3/bin/python)
frame #30: _PyEval_EvalFrameDefault + 0x96c (0x5564b98ce4ac in /home/ubuntu/anaconda3/bin/python)
frame #31: _PyEval_EvalCodeWithName + 0x242 (0x5564b9815af2 in /home/ubuntu/anaconda3/bin/python)
frame #32: _PyFunction_FastCallKeywords + 0x320 (0x5564b9844030 in /home/ubuntu/anaconda3/bin/python)
frame #33: <unknown function> + 0x18f9c8 (0x5564b98a79c8 in /home/ubuntu/anaconda3/bin/python)
frame #34: _PyEval_EvalFrameDefault + 0x1699 (0x5564b98cf1d9 in /home/ubuntu/anaconda3/bin/python)
frame #35: _PyEval_EvalCodeWithName + 0x242 (0x5564b9815af2 in /home/ubuntu/anaconda3/bin/python)
frame #36: PyEval_EvalCodeEx + 0x39 (0x5564b9816d09 in /home/ubuntu/anaconda3/bin/python)
frame #37: PyEval_EvalCode + 0x1b (0x5564b98f18ab in /home/ubuntu/anaconda3/bin/python)
frame #38: <unknown function> + 0x23df53 (0x5564b9955f53 in /home/ubuntu/anaconda3/bin/python)
frame #39: PyRun_StringFlags + 0x7d (0x5564b995f8cd in /home/ubuntu/anaconda3/bin/python)
frame #40: PyRun_SimpleStringFlags + 0x3d (0x5564b995f92d in /home/ubuntu/anaconda3/bin/python)
frame #41: <unknown function> + 0x248487 (0x5564b9960487 in /home/ubuntu/anaconda3/bin/python)
frame #42: _Py_UnixMain + 0x3c (0x5564b996085c in /home/ubuntu/anaconda3/bin/python)
frame #43: __libc_start_main + 0xe7 (0x7fad6f027bf7 in /lib/x86_64-linux-gnu/libc.so.6)
frame #44: <unknown function> + 0x1c5901 (0x5564b98dd901 in /home/ubuntu/anaconda3/bin/python)

Traceback (most recent call last):
  File "run_simsiam.py", line 346, in <module>
    main()
  File "run_simsiam.py", line 138, in main
    mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
  File "/home/ubuntu/anaconda3/lib/python3.7/site-packages/torch/multiprocessing/spawn.py", line 230, in spawn
    return start_processes(fn, args, nprocs, join, daemon, start_method='spawn')
  File "/home/ubuntu/anaconda3/lib/python3.7/site-packages/torch/multiprocessing/spawn.py", line 188, in start_processes
    while not context.join():
  File "/home/ubuntu/anaconda3/lib/python3.7/site-packages/torch/multiprocessing/spawn.py", line 150, in join
    raise ProcessRaisedException(msg, error_index, failed_process.pid)
torch.multiprocessing.spawn.ProcessRaisedException: 

-- Process 2 terminated with the following error:
Traceback (most recent call last):
  File "/home/ubuntu/anaconda3/lib/python3.7/site-packages/torch/multiprocessing/spawn.py", line 59, in _wrap
    fn(i, *args)
  File "/home/ubuntu/nlp_ssl/run_simsiam.py", line 271, in main_worker
    train(train_loader, model, criterion, optimizer, epoch, lr_scheduler, args)
  File "/home/ubuntu/nlp_ssl/run_simsiam.py", line 309, in train
    p1, p2, z1, z2 = model(input_1, input_2)
  File "/home/ubuntu/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/ubuntu/anaconda3/lib/python3.7/site-packages/torch/nn/parallel/distributed.py", line 799, in forward
    output = self.module(*inputs, **kwargs)
  File "/home/ubuntu/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/ubuntu/nlp_ssl/model.py", line 69, in forward
    z1 = self.projector(z1) # NxD
  File "/home/ubuntu/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/ubuntu/anaconda3/lib/python3.7/site-packages/torch/nn/modules/container.py", line 139, in forward
    input = module(input)
  File "/home/ubuntu/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/ubuntu/anaconda3/lib/python3.7/site-packages/torch/nn/modules/batchnorm.py", line 757, in forward
    world_size,
  File "/home/ubuntu/anaconda3/lib/python3.7/site-packages/torch/nn/modules/_functions.py", line 56, in forward
    out = torch.batch_norm_elemt(input, weight, bias, mean, invstd, eps)
RuntimeError: CUDA error: an illegal memory access was encountered

My batch sizes are 64 or 128 depending on the GPUs I am using. The issue goes away when I remove the batchnorm layers. My question is : Is this issue solved in torch 1.10 nightly build since I am getting the same error in torch 1.10.1 and torch 1.9?

Do you see this BatchNorm error in local training or is it specifically related to distributed training?

cc @ptrblck

Could you post the command you’ve used to run the training or (even better) a minimal code snippet to create the model instance and the input shapes of all tensors to reproduce the issue, please?

I still see the error for 1 GPU training on torch 1.9; if I remove the batchnorm layer from the predictor module, the code runs. The code also runs fine on 1 GPU on torch 1.10.1.

Epoch: [0][   0/5906]   Time  0.732 ( 0.732)    Data  0.017 ( 0.017)    Loss 0.0012 (0.0012)
Traceback (most recent call last):
  File "test_one_gpu.py", line 92, in <module>
    train(train_loader, model, criterion, optimizer, epoch, lr_scheduler)
  File "test_one_gpu.py", line 78, in train
    z2 = self.projector(z2) # NxD
  File "/home/ubuntu/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/ubuntu/anaconda3/lib/python3.7/site-packages/torch/nn/modules/container.py", line 139, in forward
    input = module(input)
  File "/home/ubuntu/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/ubuntu/anaconda3/lib/python3.7/site-packages/torch/nn/modules/batchnorm.py", line 178, in forward
    self.eps,
  File "/home/ubuntu/anaconda3/lib/python3.7/site-packages/torch/nn/functional.py", line 2282, in batch_norm
    input, weight, bias, running_mean, running_var, training, momentum, eps, torch.backends.cudnn.enabled
RuntimeError: CUDA error: an illegal memory access was encountered

I will post the model definition and the expected sizes of tensors in response to Patrick’s answer. Thank you.

Here is the model description. I am using SimSiam to train some sentence transformers.

class SimSiam(nn.Module):
    """
    Build a SimSiam model on top of a BERT sentence encoder.

    A 3-layer MLP projector maps the pooled BERT output (768-d) to `dim`,
    and a 2-layer bottleneck MLP predictor maps `dim` -> `pred_dim` -> `dim`.
    See Sec. 3 of https://arxiv.org/abs/2011.10566 for notation.
    """
    def __init__(self, dim=2048, pred_dim=512, dropout_prob=.5):
        """
        dim: feature dimension of the projector output (default: 2048)
        pred_dim: hidden dimension of the predictor (default: 512);
            conventionally pred_dim = dim // 4
        dropout_prob: dropout applied to the pooled encoder output
        """
        super(SimSiam, self).__init__()

        # create the encoder (num_classes-style fc head is replaced by the
        # projector below).
        # BUG FIX: the original code also did `self.path = path`, but `path`
        # is not defined anywhere in scope, so construction raised NameError.
        self.encoder = BertModel.from_pretrained('bert-base-uncased')
        self.dropout_prob = dropout_prob
        self.dropout = nn.Dropout(dropout_prob)

        # build a 3-layer projector: 768 -> 768 -> 768 -> dim
        prev_dim = self.encoder.config.hidden_size  # 768 for bert-base-uncased
        self.projector = nn.Sequential(nn.Linear(prev_dim, prev_dim, bias=False),
                                        nn.BatchNorm1d(prev_dim),
                                        nn.ReLU(inplace=True),  # first layer, output_dim = 768
                                        nn.Linear(prev_dim, prev_dim, bias=False),
                                        nn.BatchNorm1d(prev_dim),
                                        nn.ReLU(inplace=True),  # second layer, output_dim = 768
                                        nn.Linear(prev_dim, dim),
                                        nn.BatchNorm1d(dim, affine=False)
                                        )  # output layer, output_dim = dim (2048)
        # hack: do not train the output-layer bias, as it is followed by BN
        self.projector[6].bias.requires_grad = False

        # build a 2-layer predictor: dim -> pred_dim -> dim
        self.predictor = nn.Sequential(nn.Linear(dim, pred_dim, bias=False),  # output_dim = pred_dim
                                        nn.BatchNorm1d(pred_dim),  # commenting this out makes the training run on 1 GPU on torch 1.9
                                        nn.ReLU(inplace=True),  # hidden layer
                                        nn.Linear(pred_dim, dim))  # output layer, output_dim = dim

    def forward(self, input_ids_1, attention_mask_1, input_ids_2, attention_mask_2):
        """
        Input:
            input_ids_1, attention_mask_1: tokenized first views of the batch
            input_ids_2, attention_mask_2: tokenized second views of the batch
        Output:
            p1, p2, z1, z2: predictor outputs and (detached) projector targets
            See Sec. 3 of https://arxiv.org/abs/2011.10566 for detailed notations
        """

        # compute projected features for each view
        # NOTE(review): z[1] is assumed to be the pooled ([CLS]) BERT output
        # of shape Nx768 — confirm against the transformers version in use.
        z1 = self.encoder(input_ids=input_ids_1, attention_mask=attention_mask_1)
        z1 = self.dropout(z1[1])  # Nx768
        z1 = self.projector(z1)  # Nx2048

        z2 = self.encoder(input_ids=input_ids_2, attention_mask=attention_mask_2)
        z2 = self.dropout(z2[1])  # Nx768
        z2 = self.projector(z2)  # Nx2048

        p1 = self.predictor(z1)  # Nx2048
        p2 = self.predictor(z2)  # Nx2048

        # targets are detached: stop-gradient is essential to SimSiam
        return p1, p2, z1.detach(), z2.detach()

I am launching the code by

CUDA_LAUNCH_BLOCKING=1 python run_simsiam.py \
  --dist-url 'tcp://localhost:10001' --multiprocessing-distributed --world-size 1 --rank 0

I will post a minimal example of the training script later today.

Here is the MWE using the same above model.

import argparse
import builtins
import math
import os
import random
import shutil
import time
import warnings
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.optim
import torch.multiprocessing as mp
import torch.utils.data
import torch.utils.data.distributed

from torch.utils.data import Dataset

from transformers import BertTokenizer, BertModel, AutoTokenizer
from transformers import AdamW

from model import *

'''
To run this code: python mwe.py \
  --dist-url 'tcp://localhost:10001' --multiprocessing-distributed --world-size 1 --rank 0
'''

#Add in other backbones like biomed Roberta and clicnical longformer

# Command-line interface mirroring the official SimSiam training script.
parser = argparse.ArgumentParser(description='SimSiam Training')
parser.add_argument('--data', metavar='DIR', default='/home/ubuntu/ssl_data/all_body_part_data_02032022.pkl',
                    help='path to dataset')
parser.add_argument('-j', '--workers', default=32, type=int, metavar='N',
                    help='number of data loading workers (default: 32)')
parser.add_argument('--world-size', default=-1, type=int,
                    help='number of nodes for distributed training')
parser.add_argument('--rank', default=-1, type=int,
                    help='node rank for distributed training')
parser.add_argument('--dist-url', default='tcp://', type=str,
                    help='url used to set up distributed training')
parser.add_argument('--dist-backend', default='nccl', type=str,
                    help='distributed backend')
parser.add_argument('--seed', default=None, type=int,
                    help='seed for initializing training. ')
parser.add_argument('--gpu', default=None, type=int,
                    help='GPU id to use.')
parser.add_argument('--multiprocessing-distributed', action='store_true',
                    help='Use multi-processing distributed training to launch '
                         'N processes per node, which has N GPUs. This is the '
                         'fastest way to use PyTorch for either single node or '
                         'multi node data parallel training')
parser.add_argument('--batch_size', default=64, type=int)


# parsed once at import time; all workers read this module-level namespace
args = parser.parse_args()

#simulate tokenized data: 10000 sequences of 256 token ids plus all-ones masks
inps = torch.randint(0,30000, size=(10000, 256))
att = torch.ones((10000,256))

class My_Dataset(Dataset):
    """Minimal map-style dataset pairing token-id rows X with mask rows Y."""

    def __init__(self, X, Y):
        # Keep references only; no copying or preprocessing is performed.
        self.X = X
        self.Y = Y

    def __len__(self):
        # Length is driven by X; Y is assumed to be the same length.
        return len(self.X)

    def __getitem__(self, idx):
        # Yield the (input, mask) pair for one sample.
        return self.X[idx], self.Y[idx]

# Wrap the simulated tensors in a Dataset; lr is the base learning rate
# before the linear batch-size scaling applied in main_worker.
ds = My_Dataset(inps,att)
lr = .0001

def main():
    """Entry point: configure seeding and distributed flags, then launch
    one worker per GPU (via mp.spawn) or a single in-process worker."""

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')
    else:
        # BUG FIX: the original code set cudnn.deterministic = False
        # unconditionally after the seeded branch above, silently undoing
        # the deterministic mode it had just enabled (and warned about).
        # Only disable determinism when no seed was requested.
        cudnn.deterministic = False

    if args.gpu is not None:
        warnings.warn('You have chosen a specific GPU. This will completely '
                      'disable data parallelism.')

    if args.dist_url == "env://" and args.world_size == -1:
        args.world_size = int(os.environ["WORLD_SIZE"])

    args.distributed = args.world_size > 1 or args.multiprocessing_distributed

    ngpus_per_node = torch.cuda.device_count()
    if args.multiprocessing_distributed:
        # Since we have ngpus_per_node processes per node, the total
        # world_size needs to be adjusted accordingly.
        # NOTE(review): with the default world_size == -1 this product goes
        # negative; the documented launch command passes --world-size 1 —
        # confirm all callers do the same.
        args.world_size = ngpus_per_node * args.world_size
        # Use torch.multiprocessing.spawn to launch distributed processes:
        # one main_worker process per GPU.
        mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
    else:
        # Simply call main_worker in this process
        main_worker(args.gpu, ngpus_per_node, args)


def main_worker(gpu, ngpus_per_node, args):
    """Per-process training worker.

    Initializes the (optional) distributed process group, builds the SimSiam
    model, loss, optimizer, and data loader, then runs 100 training epochs.

    Args:
        gpu: local GPU index for this process (or None).
        ngpus_per_node: number of GPUs on this node.
        args: parsed command-line namespace (mutated in place: gpu, rank,
            batch_size, workers).
    """
    args.gpu = gpu

    # suppress printing if not master
    if args.multiprocessing_distributed and args.gpu != 0:
        def print_pass(*args):
            pass
        builtins.print = print_pass

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            # For multiprocessing distributed training, rank needs to be the
            # global rank among all the processes
            args.rank = args.rank * ngpus_per_node + gpu
        # FIX: pin this process to its own device *before* creating the
        # process group or issuing any collective (the barrier below).
        # With the NCCL backend, collectives launched before set_device make
        # every rank target the default device (cuda:0), which manifests as
        # "CUDA error: an illegal memory access was encountered" a step or
        # two into training — the original code only called set_device after
        # the barrier.
        if args.gpu is not None:
            torch.cuda.set_device(args.gpu)
        dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                world_size=args.world_size, rank=args.rank)
        torch.distributed.barrier()
    # create model
    model = SimSiam()

    # infer learning rate before changing batch size
    init_lr = lr * args.batch_size / 256

    if args.distributed:
        # Apply SyncBN
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
        # For multiprocessing distributed, DistributedDataParallel constructor
        # should always set the single device scope, otherwise,
        # DistributedDataParallel will use all available devices.
        if args.gpu is not None:
            # (device was already selected above, before the process group)
            model.cuda(args.gpu)
            # When using a single GPU per process and per
            # DistributedDataParallel, we need to divide the batch size
            # ourselves based on the total number of GPUs we have
            args.batch_size = int(args.batch_size / ngpus_per_node)
            args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node)
            model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        else:
            model.cuda()
            # DistributedDataParallel will divide and allocate batch_size to all
            # available GPUs if device_ids are not set
            model = torch.nn.parallel.DistributedDataParallel(model)
    elif args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)
        # comment out the following line for debugging
        raise NotImplementedError("Only DistributedDataParallel is supported.")
    else:
        # AllGather implementation (batch shuffle, queue update, etc.) in
        # this code only supports DistributedDataParallel.
        raise NotImplementedError("Only DistributedDataParallel is supported.")

    # define loss function (criterion) and optimizer
    criterion = nn.CosineSimilarity(dim=1).cuda(args.gpu)

    optimizer = AdamW(model.parameters(), init_lr,
                      weight_decay=1e-4)

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(ds)
    else:
        train_sampler = None

    # drop_last=True avoids a trailing batch of size 1, which BatchNorm1d
    # cannot normalize in training mode.
    train_loader = torch.utils.data.DataLoader(
        ds, batch_size=args.batch_size, shuffle=(train_sampler is None),
        num_workers=args.workers, pin_memory=True, sampler=train_sampler, drop_last=True)

    for epoch in range(100):
        if args.distributed:
            # Reshuffle each rank's shard for the new epoch.
            train_sampler.set_epoch(epoch)

        train(train_loader, model, criterion, optimizer, args)

def train(train_loader, model, criterion, optimizer, args):
    """Run one epoch of SimSiam-style training over ``train_loader``.

    Feeds the same (input_ids, attention_mask) view into both branches of the
    model and minimizes the negative symmetric cosine similarity between each
    predictor output and the opposite branch's projection.
    """
    # Training mode enables dropout / batch-norm statistic updates.
    model.train()

    for step, batch in enumerate(train_loader):
        optimizer.zero_grad()

        # Move both tensors to this process's device.
        batch[0] = batch[0].cuda(args.gpu, non_blocking=True)
        batch[1] = batch[1].cuda(args.gpu, non_blocking=True)

        # Same data is used for both views (no augmentation here).
        p1, p2, z1, z2 = model(input_ids_1=batch[0], attention_mask_1=batch[1],
                               input_ids_2=batch[0], attention_mask_2=batch[1])

        # Negative symmetric cosine similarity (SimSiam objective).
        loss = -0.5 * (criterion(p1, z2).mean() + criterion(p2, z1).mean())

        loss.backward()
        optimizer.step()

        # Periodic progress readout.
        if step % 20 == 0:
            print(loss.item())

# Script entry point: module-level setup above runs at import; training
# only starts when executed directly.
if __name__ == '__main__':
    main()