Dear PyTorch Community,
The source code for this question can be found here:
Google Colab notebook saved in my public GitHub repository
I have been attempting to train a PyTorch model using the PyTorch XLA device (TPU) in Google Colab, but an error is preventing me from completing the training.
After a few attempts at debugging the issue, I modified the _visual_to_stacked_tensor method of my DataCollator to the code below, in an attempt to print the shape and device of the tensors and make sure they were all instantiated on the TPU device:
from typing import List

import torch

def _visual_to_stacked_tensor(self, videos: List[torch.Tensor]) -> torch.Tensor:
    """
    Stack together tensors representing feature vectors of images.
    :param videos: The list of individual tensors for visual data
    :return: The collated visual data as a tensor of shape (batch_size, *visual_in_features)
    """
    print("data collator, visual:")
    for v in videos:
        print((v.device, v.shape))
    return torch.stack(videos, dim=0)
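For reference, a guarded variant I could use so that a single failing tensor doesn't kill the whole loader thread (the try/except is just a debugging sketch of my own, not something that was in the notebook):

from typing import List

import torch

def _visual_to_stacked_tensor(self, videos: List[torch.Tensor]) -> torch.Tensor:
    print("data collator, visual:")
    for v in videos:
        try:
            # reading .shape on an XLA tensor may require live backend
            # data, which is what appears to blow up in the loader thread
            print((v.device, v.shape))
        except RuntimeError as e:
            print(f"shape unavailable on {v.device}: {e}")
    return torch.stack(videos, dim=0)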
However, an exception is raised while trying to print the shape of one of the tensors.
My dataset is fairly large (it's the MELD Dataset), and I load the entire thing onto the TPU device at the start; it must take around 40GB. I had no problems with the tensors themselves being on the target device before I started processing.
However, given the size of the dataset I'm loading onto the TPU device, I suspect the TPU may be running out of memory for the training operations themselves, such as creating the model parameters that will be trained. That said, I don't see anything that confirms this is the case.
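To test the out-of-memory theory, I believe torch_xla exposes a device memory query; a small check like the one below, run right after loading the dataset, could confirm how much free memory remains (assuming xm.get_memory_info behaves in the Colab build the way I expect):

import torch_xla.core.xla_model as xm

device = xm.xla_device()
# should return something like {'kb_free': ..., 'kb_total': ...};
# a tiny kb_free after loading the dataset would support the OOM theory
print(xm.get_memory_info(device))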
Would anybody here know what it could be?
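In case it helps to rule this out: my understanding is that the usual torch_xla pattern is to keep the dataset on the CPU and let the parallel loader transfer each batch to the TPU as it is consumed, so device memory only ever holds one batch of inputs at a time. A minimal sketch of that pattern (the toy TensorDataset here is just illustrative, not my real dataset):

import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
from torch.utils.data import DataLoader, TensorDataset

device = xm.xla_device()

# toy stand-in for the real dataset: these tensors stay on the CPU
dataset = TensorDataset(torch.randn(128, 3, 224, 224), torch.randint(0, 7, (128,)))
cpu_loader = DataLoader(dataset, batch_size=32)

# MpDeviceLoader moves each batch to the TPU as it is consumed
device_loader = pl.MpDeviceLoader(cpu_loader, device)

for x, y in device_loader:
    print(x.device, x.shape)  # expect: xla:0 torch.Size([32, 3, 224, 224])
    break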
Error message:
Using cache found in /root/.cache/torch/hub/pytorch_vision_v0.6.0
Exception in thread Thread-22 (_loader_worker):
Traceback (most recent call last):
File "/usr/lib/python3.10/threading.py", line 1016, in _bootstrap_inner
self.run()
File "/usr/lib/python3.10/threading.py", line 953, in run
self._target(*self._args, **self._kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch_xla/distributed/parallel_loader.py", line 140, in _loader_worker
_, data = next(data_iter)
File "/usr/local/lib/python3.10/dist-packages/accelerate/data_loader.py", line 384, in __iter__
current_batch = next(dataloader_iter)
File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 633, in __next__
data = self._next_data()
File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 677, in _next_data
data = self._dataset_fetcher.fetch(index) # may raise StopIteration
File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/_utils/fetch.py", line 54, in fetch
return self.collate_fn(data)
File "/usr/local/lib/python3.10/dist-packages/transformers/trainer_utils.py", line 704, in __call__
return self.data_collator(features)
File "/content/src/hlm12erc/training/erc_data_collator.py", line 53, in __call__
x_visual = self._visual_to_stacked_tensor([r.visual for r in batch])
File "/content/src/hlm12erc/training/erc_data_collator.py", line 72, in _visual_to_stacked_tensor
print((v.device, v.shape))
RuntimeError: torch_xla/csrc/tensor.cpp:173 : Check failed: data()->tensor_data
*** Begin stack trace ***
tsl::CurrentStackTrace()
torch_xla::XLATensor::shape() const
torch_xla::XLATensorImpl::SetupSizeProperties()
torch_xla::XLATensorImpl::sym_sizes_custom() const
THPSize_NewFromSymSizes(at::Tensor const&)
THPVariable_get_shape(THPVariable*, void*)
_PyObject_GenericGetAttrWithDict
PyObject_GetAttr
_PyEval_EvalFrameDefault
_PyFunction_Vectorcall
_PyEval_EvalFrameDefault
_PyObject_FastCallDictTstate
_PyObject_Call_Prepend
_PyObject_MakeTpCall
_PyEval_EvalFrameDefault
_PyObject_FastCallDictTstate
_PyObject_Call_Prepend
_PyObject_MakeTpCall
_PyEval_EvalFrameDefault
_PyFunction_Vectorcall
_PyEval_EvalFrameDefault
_PyFunction_Vectorcall
_PyEval_EvalFrameDefault
_PyEval_EvalFrameDefault
_PyEval_EvalFrameDefault
_PyEval_EvalFrameDefault
_PyFunction_Vectorcall
_PyEval_EvalFrameDefault
_PyFunction_Vectorcall
_PyEval_EvalFrameDefault
*** End stack trace ***
data collator, visual:
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-18-7649775cb42d> in <cell line: 1>()
----> 1 baseline_model_name, baseline_model_instance = train_model(
2 "baseline",
3 n_epochs=15,
4 batch_size=32,
5 device=tpu)
/usr/local/lib/python3.10/dist-packages/torch_xla/core/xla_model.py in mark_step(wait)
947 file=sys.stderr,
948 flush=True)
--> 949 torch_xla._XLAC._xla_step_marker(
950 torch_xla._XLAC._xla_get_default_device(), [],
951 wait=xu.getenv_as('XLA_SYNC_WAIT', bool, wait))
RuntimeError: torch_xla/csrc/xla_graph_executor.cpp:523 : Check failed: tensor_data
*** Begin stack trace ***
tsl::CurrentStackTrace()
torch_xla::XLAGraphExecutor::CollectSyncTensors(std::vector<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> >, std::allocator<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> > > > const&, torch::lazy::LazyGraphExecutor::SyncTensorsConfig const&)
torch_xla::XLAGraphExecutor::SyncTensorsGraphInternal(std::vector<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> >, std::allocator<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> > > >*, absl::lts_20220623::Span<std::string const>, torch::lazy::LazyGraphExecutor::SyncTensorsConfig const&, bool)
torch_xla::XLAGraphExecutor::SyncTensorsGraph(std::vector<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> >, std::allocator<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> > > >*, absl::lts_20220623::Span<std::string const>, bool, bool, bool)
torch_xla::XLAGraphExecutor::SyncLiveTensorsGraph(torch::lazy::BackendDevice const*, c10::ArrayRef<std::string>, bool)
_PyObject_MakeTpCall
_PyEval_EvalFrameDefault
_PyFunction_Vectorcall
_PyEval_EvalFrameDefault
_PyFunction_Vectorcall
_PyEval_EvalFrameDefault
_PyEval_EvalFrameDefault
PyObject_Call
_PyObject_MakeTpCall
_PyEval_EvalFrameDefault
_PyFunction_Vectorcall
_PyEval_EvalFrameDefault
_PyEval_EvalFrameDefault
_PyFunction_Vectorcall
_PyEval_EvalFrameDefault
PyEval_EvalCode
_PyEval_EvalFrameDefault
_PyEval_EvalFrameDefault
_PyEval_EvalFrameDefault
_PyEval_EvalFrameDefault
_PyFunction_Vectorcall
_PyEval_EvalFrameDefault
_PyFunction_Vectorcall
_PyEval_EvalFrameDefault
PyObject_Call
_PyEval_EvalFrameDefault
_PyEval_EvalFrameDefault
_PyEval_EvalFrameDefault
_PyFunction_Vectorcall
_PyEval_EvalFrameDefault
_PyEval_EvalFrameDefault
_PyEval_EvalFrameDefault
_PyEval_EvalFrameDefault
_PyEval_EvalFrameDefault
_PyEval_EvalFrameDefault
_PyEval_EvalFrameDefault
_PyEval_EvalFrameDefault
_PyFunction_Vectorcall
_PyEval_EvalFrameDefault
_PyFunction_Vectorcall
_PyEval_EvalFrameDefault
_PyFunction_Vectorcall
_PyEval_EvalFrameDefault
_PyFunction_Vectorcall
_PyEval_EvalFrameDefault
_PyFunction_Vectorcall
_PyEval_EvalFrameDefault
_PyFunction_Vectorcall
_PyEval_EvalFrameDefault
_PyFunction_Vectorcall
_PyEval_EvalFrameDefault
_PyEval_EvalFrameDefault
PyEval_EvalCode
_PyEval_EvalFrameDefault
_PyFunction_Vectorcall
_PyEval_EvalFrameDefault
_PyFunction_Vectorcall
Py_RunMain
Py_BytesMain
__libc_start_main
_start
*** End stack trace ***