How does the dataloader create a batch of tensors? Is it simply doing `torch.cat([tensor_1, tensor_2])`?
Trying to understand whether there is more optimization than that or not.
Depending on the type your `__getitem__` method from the `Dataset` returns, `default_collate` will use `torch.stack`, create a `torch.tensor`, etc.
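The default behavior can be inspected by calling the collate function directly — a minimal sketch (in recent PyTorch releases it is importable as `torch.utils.data.default_collate`; older versions expose it in `torch.utils.data._utils.collate`):

```python
import torch
from torch.utils.data import default_collate

# A list of same-shaped tensors is stacked along a new batch dim 0
samples = [torch.randn(3, 24, 24) for _ in range(4)]
print(default_collate(samples).shape)  # torch.Size([4, 3, 24, 24])

# Plain Python numbers are converted into a single tensor
print(default_collate([1, 2, 3]))  # tensor([1, 2, 3])

# Dicts are collated per key, recursively
dicts = [{'x': torch.randn(2), 'y': 0} for _ in range(3)]
out = default_collate(dicts)
print(out['x'].shape, out['y'])  # torch.Size([3, 2]) tensor([0, 0, 0])
```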
Hi @ptrblck, we want to put some JSON parsing logic into the dataloader to leverage the multiple-worker capability. Can the return type just be a dict? It seems the former 1.0.0 version returned a pure dict, but later versions create tensors?
Here is our `__getitem__`:

```python
with open(current_file, mode='rb') as f:
    text = f.read().decode('utf-8')
all_data.extend(text.split('\n'))

json_data = []
for line in all_data:
    try:
        json_data.append(json.loads(line))
    except json.JSONDecodeError:
        break
return json_data
```
It seems you could still return a dict. At least this small code snippet works in the nightly binary from approx. a week ago:
```python
import torch
from torch.utils.data import Dataset, DataLoader

class MyDataset(Dataset):
    def __init__(self):
        pass

    def __getitem__(self, index):
        return {'x': torch.randn(3, 24, 24), 'y': torch.randint(0, 10, (1,))}

    def __len__(self):
        return 10

dataset = MyDataset()
loader = DataLoader(dataset, batch_size=5, num_workers=0)

for batch in loader:
    print(batch['x'].shape)
    print(batch['y'].shape)
```
Are you getting an error and if so, could you post the error message?
Thanks @ptrblck. My error message is:
```
ERROR: Unexpected segmentation fault encountered in worker.
Traceback (most recent call last):
  File "/home/miniconda/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 480, in _try_get_batch
    data = self.data_queue.get(timeout=timeout)
  File "/home/miniconda/lib/python3.6/multiprocessing/queues.py", line 104, in get
    if not self._poll(timeout):
  File "/home/miniconda/lib/python3.6/multiprocessing/connection.py", line 257, in poll
    return self._poll(timeout)
  File "/home/miniconda/lib/python3.6/multiprocessing/connection.py", line 414, in _poll
    r = wait([self], timeout)
  File "/home/miniconda/lib/python3.6/multiprocessing/connection.py", line 911, in wait
    ready = selector.select(timeout)
  File "/home/miniconda/lib/python3.6/selectors.py", line 376, in select
    fd_event_list = self._poll.poll(timeout)
  File "/home/miniconda/lib/python3.6/site-packages/torch/utils/data/_utils/signal_handling.py", line 65, in handler
    _error_if_any_worker_fails()
RuntimeError: DataLoader worker (pid 95106) is killed by signal: Segmentation fault.
```
One interesting thing: when I just read the data line by line (each line is a JSON string), I do not have this issue:

```python
with open(current_file, mode='rb') as f:
    text = f.read().decode('utf-8')
all_data.extend(text.split('\n'))
```
But if I add the JSON parsing logic after reading line by line, it reports this error:

```python
with open(current_file, mode='rb') as f:
    text = f.read().decode('utf-8')
all_data.extend(text.split('\n'))

json_data = []
for line in all_data:
    try:
        json_data.append(json.loads(line))
    except json.JSONDecodeError:
        break
return json_data
```
I searched around and found multiple GitHub issues pointing to a lack of shm as the cause. However, I have sufficient shm size. I understand there is some JSON memory overhead, but even when I decrease the number of workers to 2 and use a very small dataset, it still has the same problem. I doubt it is shm-related. Any clue?
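For reference (not from the original reply), on Linux the shared-memory headroom that DataLoader workers rely on can be checked like this; `/dev/shm` is the default mount point:

```shell
# DataLoader workers pass batches through shared memory; if this
# mount runs out of space, workers can die unexpectedly
df -h /dev/shm
```

Separately, rerunning the loader with `num_workers=0` executes `__getitem__` in the main process, which usually surfaces the underlying Python traceback instead of an opaque "worker killed by signal" error.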
Also, another interesting part: when I directly return the `json.loads` result from `__getitem__`, it auto-converts to tensor type. Is this something the new PyTorch dataloader does automatically?
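On that last point: `default_collate` recursively converts numbers inside returned dicts into tensors. If the raw parsed dicts are wanted unchanged, one option (a sketch, not something from this thread; `JsonDataset` is a hypothetical name) is to pass an identity-style `collate_fn`:

```python
from torch.utils.data import Dataset, DataLoader

class JsonDataset(Dataset):
    """Hypothetical dataset returning already-parsed JSON dicts."""
    def __init__(self, records):
        self.records = records

    def __getitem__(self, index):
        return self.records[index]

    def __len__(self):
        return len(self.records)

records = [{'id': i, 'label': i % 2} for i in range(4)]

# collate_fn=list keeps each batch as a plain list of dicts,
# bypassing default_collate's tensor conversion entirely
loader = DataLoader(JsonDataset(records), batch_size=2, collate_fn=list)

for batch in loader:
    print(batch)
```

Each `batch` is then a plain Python list of the dicts returned by `__getitem__`, with no tensor conversion applied.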