I am fine-tuning Llama-2-7B with FSDP on two A100 80GB GPUs, using a PyTorch nightly build. The model (with LoRA adapters from PEFT and activation checkpointing enabled) runs for a couple of steps and then crashes with the following stack trace:
Traceback (most recent call last):
File "/cluster/project/sachan/kushal/llama-exp/llama_rl_train.py", line 116, in train
train_loss.backward()
File "/cluster/project/sachan/kushal/llenv/lib64/python3.8/site-packages/torch/_tensor.py", line 491, in backward
torch.autograd.backward(
File "/cluster/project/sachan/kushal/llenv/lib64/python3.8/site-packages/torch/autograd/__init__.py", line 251, in backward
Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
File "/cluster/project/sachan/kushal/llenv/lib64/python3.8/site-packages/torch/utils/checkpoint.py", line 1071, in unpack_hook
frame.recompute_fn(*args)
File "/cluster/project/sachan/kushal/llenv/lib64/python3.8/site-packages/torch/utils/checkpoint.py", line 1194, in recompute_fn
fn(*args, **kwargs)
File "/cluster/project/sachan/kushal/llenv/lib64/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/cluster/project/sachan/kushal/llenv/lib64/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/cluster/project/sachan/kushal/llenv/lib64/python3.8/site-packages/transformers/models/llama/modeling_llama.py", line 408, in forward
hidden_states, self_attn_weights, present_key_value = self.self_attn(
File "/cluster/project/sachan/kushal/llenv/lib64/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/cluster/project/sachan/kushal/llenv/lib64/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/cluster/project/sachan/kushal/llenv/lib64/python3.8/site-packages/transformers/models/llama/modeling_llama.py", line 305, in forward
query_states = self.q_proj(hidden_states)
File "/cluster/project/sachan/kushal/llenv/lib64/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/cluster/project/sachan/kushal/llenv/lib64/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/cluster/project/sachan/kushal/llenv/lib64/python3.8/site-packages/peft/tuners/lora.py", line 817, in forward
result = F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias)
RuntimeError: setStorage: sizes [4096, 4096], strides [1, 4096], storage offset 0, and itemsize 2 requiring a storage size of 33554432 are out of bounds for storage of size 0
Exception raised from checkInBoundsForStorage at ../aten/src/ATen/native/Resize.h:92 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x57 (0x2b8cc8668647 in /cluster/project/sachan/kushal/llenv/lib64/python3.8/site-packages/torch/lib/libc10.so)
frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, std::string const&) + 0x64 (0x2b8cc86248f9 in /cluster/project/sachan/kushal/llenv/lib64/python3.8/site-packages/torch/lib/libc10.so)
frame #2: <unknown function> + 0x1f78319 (0x2b8c5eea2319 in /cluster/project/sachan/kushal/llenv/lib64/python3.8/site-packages/torch/lib/libtorch_cpu.so)
frame #3: at::native::as_strided_tensorimpl(at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>, c10::optional<long>) + 0x104 (0x2b8c5ee99ee4 in /cluster/project/sachan/kushal/llenv/lib64/python3.8/site-packages/torch/lib/libtorch_cpu.so)
frame #4: <unknown function> + 0x31883a5 (0x2b8c7790f3a5 in /cluster/project/sachan/kushal/llenv/lib64/python3.8/site-packages/torch/lib/libtorch_cuda.so)
frame #5: <unknown function> + 0x31bff93 (0x2b8c77946f93 in /cluster/project/sachan/kushal/llenv/lib64/python3.8/site-packages/torch/lib/libtorch_cuda.so)
frame #6: at::_ops::as_strided::call(at::Tensor const&, c10::ArrayRef<c10::SymInt>, c10::ArrayRef<c10::SymInt>, c10::optional<c10::SymInt>) + 0x1e6 (0x2b8c5f35e4f6 in /cluster/project/sachan/kushal/llenv/lib64/python3.8/site-packages/torch/lib/libtorch_cpu.so)
frame #7: at::Tensor::as_strided_symint(c10::ArrayRef<c10::SymInt>, c10::ArrayRef<c10::SymInt>, c10::optional<c10::SymInt>) const + 0x4a (0x2b8c5eea110a in /cluster/project/sachan/kushal/llenv/lib64/python3.8/site-packages/torch/lib/libtorch_cpu.so)
frame #8: at::native::transpose(at::Tensor const&, long, long) + 0x81b (0x2b8c5ee96adb in /cluster/project/sachan/kushal/llenv/lib64/python3.8/site-packages/torch/lib/libtorch_cpu.so)
frame #9: <unknown function> + 0x2c81b23 (0x2b8c5fbabb23 in /cluster/project/sachan/kushal/llenv/lib64/python3.8/site-packages/torch/lib/libtorch_cpu.so)
frame #10: at::_ops::transpose_int::call(at::Tensor const&, long, long) + 0x15f (0x2b8c5f824e7f in /cluster/project/sachan/kushal/llenv/lib64/python3.8/site-packages/torch/lib/libtorch_cpu.so)
frame #11: at::native::t(at::Tensor const&) + 0x4b (0x2b8c5ee779bb in /cluster/project/sachan/kushal/llenv/lib64/python3.8/site-packages/torch/lib/libtorch_cpu.so)
frame #12: <unknown function> + 0x2c81aed (0x2b8c5fbabaed in /cluster/project/sachan/kushal/llenv/lib64/python3.8/site-packages/torch/lib/libtorch_cpu.so)
frame #13: at::_ops::t::redispatch(c10::DispatchKeySet, at::Tensor const&) + 0x6b (0x2b8c5f7b0c7b in /cluster/project/sachan/kushal/llenv/lib64/python3.8/site-packages/torch/lib/libtorch_cpu.so)
frame #14: <unknown function> + 0x4a0a7d0 (0x2b8c619347d0 in /cluster/project/sachan/kushal/llenv/lib64/python3.8/site-packages/torch/lib/libtorch_cpu.so)
frame #15: <unknown function> + 0x4a0a9e0 (0x2b8c619349e0 in /cluster/project/sachan/kushal/llenv/lib64/python3.8/site-packages/torch/lib/libtorch_cpu.so)
frame #16: at::_ops::t::redispatch(c10::DispatchKeySet, at::Tensor const&) + 0x6b (0x2b8c5f7b0c7b in /cluster/project/sachan/kushal/llenv/lib64/python3.8/site-packages/torch/lib/libtorch_cpu.so)
frame #17: <unknown function> + 0x43b0d79 (0x2b8c612dad79 in /cluster/project/sachan/kushal/llenv/lib64/python3.8/site-packages/torch/lib/libtorch_cpu.so)
frame #18: <unknown function> + 0x43b1220 (0x2b8c612db220 in /cluster/project/sachan/kushal/llenv/lib64/python3.8/site-packages/torch/lib/libtorch_cpu.so)
frame #19: at::_ops::t::call(at::Tensor const&) + 0x12b (0x2b8c5f7f580b in /cluster/project/sachan/kushal/llenv/lib64/python3.8/site-packages/torch/lib/libtorch_cpu.so)
frame #20: at::native::linear(at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor> const&) + 0x230 (0x2b8c5ec28b20 in /cluster/project/sachan/kushal/llenv/lib64/python3.8/site-packages/torch/lib/libtorch_cpu.so)
frame #21: <unknown function> + 0x2e54ca3 (0x2b8c5fd7eca3 in /cluster/project/sachan/kushal/llenv/lib64/python3.8/site-packages/torch/lib/libtorch_cpu.so)
frame #22: at::_ops::linear::call(at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor> const&) + 0x18f (0x2b8c5f38551f in /cluster/project/sachan/kushal/llenv/lib64/python3.8/site-packages/torch/lib/libtorch_cpu.so)
frame #23: <unknown function> + 0x67befa (0x2b8c5c319efa in /cluster/project/sachan/kushal/llenv/lib64/python3.8/site-packages/torch/lib/libtorch_python.so)
RuntimeError: CUDA error: an illegal memory access was encountered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
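For context, the relevant pieces of my setup look roughly like the sketch below. The model name, LoRA config, wrapping policy, and launch details are simplified placeholders rather than the exact script:

```python
# Simplified sketch of the setup (hypothetical names/arguments; the real script
# uses its own dataset, wrapping policy, and hyperparameters).
import functools
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    apply_activation_checkpointing,
    checkpoint_wrapper,
)
from transformers import LlamaForCausalLM
from transformers.models.llama.modeling_llama import LlamaDecoderLayer
from peft import LoraConfig, get_peft_model

dist.init_process_group("nccl")  # launched with torchrun --nproc_per_node=2
torch.cuda.set_device(dist.get_rank())

model = LlamaForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf", torch_dtype=torch.bfloat16
)
# LoRA adapters via PEFT (the trace above fails inside the LoRA q_proj forward)
model = get_peft_model(
    model, LoraConfig(task_type="CAUSAL_LM", target_modules=["q_proj", "v_proj"])
)

model = FSDP(
    model,
    auto_wrap_policy=functools.partial(
        transformer_auto_wrap_policy, transformer_layer_cls={LlamaDecoderLayer}
    ),
    device_id=torch.cuda.current_device(),
)

# Activation checkpointing on each decoder layer; backward then recomputes the
# layer forward, which is where the trace above enters checkpoint.py.
apply_activation_checkpointing(
    model,
    checkpoint_wrapper_fn=checkpoint_wrapper,
    check_fn=lambda module: isinstance(module, LlamaDecoderLayer),
)

# Training loop (schematic): a few steps succeed, then backward() crashes.
# outputs = model(input_ids, labels=labels)
# outputs.loss.backward()
```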
With PyTorch 2.0.1+cu118 the model behaves differently and fails with a different error (potentially from the same root cause), as detailed here. I have tried running with compute-sanitizer and with CUDA_LAUNCH_BLOCKING=1, but the stack traces remain identical. Any ideas on how to debug this further would be helpful.
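For reference, CUDA_LAUNCH_BLOCKING has to be in the environment before the first CUDA call to take effect; a minimal way to ensure that (assuming it is not already exported by the job script) is:

```python
# Hypothetical sketch: CUDA_LAUNCH_BLOCKING is only honored if it is present before
# the CUDA context is created, so set it before anything touches the GPU.
import os
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"  # make kernel launches synchronous

import torch  # use torch only after the variable is set to keep the ordering safe
```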