Hi everyone,
I’m trying to train U2Net (GitHub: xuebinqin/U-2-Net, the code for the paper "U^2-Net: Going Deeper with Nested U-Structure for Salient Object Detection") on my custom data. Training ran for 24 epochs and then failed with this error:
Traceback (most recent call last):
File "C:\Users\Admin\U-2-Net-master\u2net_train.py", line 318, in <module>
train_loss, valid_loss = train_model(net, salobj_dataloader_train, salobj_dataloader_valid, batch_size_train, patience, epoch_num, train_num=len(X_train))
File "C:\Users\Admin\U-2-Net-master\u2net_train.py", line 203, in train_model
for i, data in enumerate(salobj_dataloader_train):
File "C:\Users\Admin\AppData\Local\Programs\Python\Python310\lib\site-packages\torch\utils\data\dataloader.py", line 634, in __next__
data = self._next_data()
File "C:\Users\Admin\AppData\Local\Programs\Python\Python310\lib\site-packages\torch\utils\data\dataloader.py", line 1346, in _next_data
return self._process_data(data)
File "C:\Users\Admin\AppData\Local\Programs\Python\Python310\lib\site-packages\torch\utils\data\dataloader.py", line 1372, in _process_data
data.reraise()
File "C:\Users\Admin\AppData\Local\Programs\Python\Python310\lib\site-packages\torch\_utils.py", line 644, in reraise
raise exception
RuntimeError: Caught RuntimeError in DataLoader worker process 0.
Original Traceback (most recent call last):
File "C:\Users\Admin\AppData\Local\Programs\Python\Python310\lib\site-packages\torch\utils\data\_utils\worker.py", line 308, in _worker_loop
data = fetcher.fetch(index)
File "C:\Users\Admin\AppData\Local\Programs\Python\Python310\lib\site-packages\torch\utils\data\_utils\fetch.py", line 54, in fetch
return self.collate_fn(data)
File "C:\Users\Admin\AppData\Local\Programs\Python\Python310\lib\site-packages\torch\utils\data\_utils\collate.py", line 264, in default_collate
return collate(batch, collate_fn_map=default_collate_fn_map)
File "C:\Users\Admin\AppData\Local\Programs\Python\Python310\lib\site-packages\torch\utils\data\_utils\collate.py", line 127, in collate
return elem_type({key: collate([d[key] for d in batch], collate_fn_map=collate_fn_map) for key in elem})
File "C:\Users\Admin\AppData\Local\Programs\Python\Python310\lib\site-packages\torch\utils\data\_utils\collate.py", line 127, in <dictcomp>
return elem_type({key: collate([d[key] for d in batch], collate_fn_map=collate_fn_map) for key in elem})
File "C:\Users\Admin\AppData\Local\Programs\Python\Python310\lib\site-packages\torch\utils\data\_utils\collate.py", line 119, in collate
return collate_fn_map[elem_type](batch, collate_fn_map=collate_fn_map)
File "C:\Users\Admin\AppData\Local\Programs\Python\Python310\lib\site-packages\torch\utils\data\_utils\collate.py", line 162, in collate_tensor_fn
return torch.stack(batch, 0, out=out)
RuntimeError: torch.cat(): input types can't be cast to the desired output type Byte
The dataset consists of images and their corresponding masks.
My code runs on the GPU. When I reran the whole script, I got the same error after 5 epochs. So I have two questions:
- Why does the error appear at different times (the 5th epoch in one run, the 24th in the other)?
- I’ve been trying to fix this, but I’m stuck and running out of ideas about what causes it. How can I approach debugging this error?
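One possible way to narrow this down, as a rough sketch rather than a confirmed fix: the traceback ends in torch.stack inside the default collate function, on the dict branch, so each sample appears to be a dict of tensors, and the message suggests that within one batch the tensors under some key do not all share a dtype (at least one being Byte, i.e. uint8). If batches are shuffled, that could also explain why the failure shows up at a different epoch on each run. The helpers below index the dataset directly, without a DataLoader or worker processes, and group sample indices by the dtype found under each key. The names scan_sample_dtypes and report_outliers are hypothetical, and the code assumes that dataset[idx] returns a dict whose values expose a .dtype attribute.

```python
from collections import defaultdict


def scan_sample_dtypes(dataset):
    """Group sample indices by the dtype seen under each key.

    Assumes dataset[idx] returns a dict (as the traceback's dict
    collate path suggests). Values without a .dtype attribute are
    recorded under their Python type name instead.
    """
    by_key = defaultdict(lambda: defaultdict(list))
    for idx in range(len(dataset)):
        sample = dataset[idx]
        for key, value in sample.items():
            dtype = str(getattr(value, "dtype", type(value).__name__))
            by_key[key][dtype].append(idx)
    return by_key


def report_outliers(by_key):
    """Return, per key, the indices whose dtype differs from the majority."""
    outliers = {}
    for key, groups in by_key.items():
        if len(groups) > 1:
            # Treat the most common dtype as the expected one.
            majority = max(groups, key=lambda d: len(groups[d]))
            outliers[key] = {
                d: idxs for d, idxs in groups.items() if d != majority
            }
    return outliers
```

Calling report_outliers(scan_sample_dtypes(train_dataset)), where train_dataset stands for whatever dataset object is passed to the DataLoader, would list the sample indices worth inspecting by hand. If the dataset applies random transforms, a single pass may miss a dtype that only appears on some code paths, so it may be worth running the scan several times.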
Thank you very much!!!