Hey,
I’m using the segmentation_models.pytorch repo (qubvel/segmentation_models.pytorch on GitHub), and I’m trying to feed it a 4-channel input image instead of the 3 channels that most of the code I’m using seems to expect. I’m getting the error below from the DataLoader and I’m stumped on how to debug it. Does anyone have any ideas? Where should I print tensor shapes?
Traceback (most recent call last):
  File "/usr/lib/python3.6/runpy.py", line 193, in _run_module_as_main
    "__main__", mod_spec)
  File "/usr/lib/python3.6/runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "/home/bhaktatejas922/internal-geometry-ml/src/train.py", line 188, in <module>
    main(cfg)
  File "/home/bhaktatejas922/internal-geometry-ml/src/train.py", line 176, in main
    **cfg.training.fit,
  File "/home/bhaktatejas922/internal-geometry-ml/src/training/runner.py", line 257, in fit
    for i, batch in enumerate(train_dataloader):
  File "/home/bhaktatejas922/.local/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 801, in __next__
    return self._process_data(data)
  File "/home/bhaktatejas922/.local/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 846, in _process_data
    data.reraise()
  File "/home/bhaktatejas922/.local/lib/python3.6/site-packages/torch/_utils.py", line 369, in reraise
    raise self.exc_type(msg)
RuntimeError: Caught RuntimeError in DataLoader worker process 1.
Original Traceback (most recent call last):
  File "/home/bhaktatejas922/.local/lib/python3.6/site-packages/torch/utils/data/_utils/worker.py", line 178, in _worker_loop
    data = fetcher.fetch(index)
  File "/home/bhaktatejas922/.local/lib/python3.6/site-packages/torch/utils/data/_utils/fetch.py", line 47, in fetch
    return self.collate_fn(data)
  File "/home/bhaktatejas922/.local/lib/python3.6/site-packages/torch/utils/data/_utils/collate.py", line 75, in default_collate
    return {key: default_collate([d[key] for d in batch]) for key in elem}
  File "/home/bhaktatejas922/.local/lib/python3.6/site-packages/torch/utils/data/_utils/collate.py", line 75, in <dictcomp>
    return {key: default_collate([d[key] for d in batch]) for key in elem}
  File "/home/bhaktatejas922/.local/lib/python3.6/site-packages/torch/utils/data/_utils/collate.py", line 65, in default_collate
    return default_collate([torch.as_tensor(b) for b in batch])
  File "/home/bhaktatejas922/.local/lib/python3.6/site-packages/torch/utils/data/_utils/collate.py", line 56, in default_collate
    return torch.stack(batch, 0, out=out)
RuntimeError: invalid argument 0: Sizes of tensors must match except in dimension 0. Got 3 and 4 in dimension 1 at /pytorch/aten/src/TH/generic/THTensor.cpp:689
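
On where to print shapes: the failure happens in default_collate, when torch.stack tries to batch the per-sample tensors, so one place to look is the output of the dataset itself, before the DataLoader gets involved. A likely but unconfirmed reading of "Got 3 and 4" is that some samples come out with 3 channels and others with 4. Below is a minimal sketch of a shape check. ToyDataset, shapes_by_key, and the "image"/"mask" keys are hypothetical stand-ins; the only thing taken from the traceback is that each sample is a dict.

```python
from collections import defaultdict

import torch
from torch.utils.data import Dataset


class ToyDataset(Dataset):
    """Hypothetical stand-in for the real dataset: sample 2 has only 3 channels."""

    def __len__(self):
        return 4

    def __getitem__(self, idx):
        channels = 3 if idx == 2 else 4
        return {
            "image": torch.zeros(channels, 8, 8),
            "mask": torch.zeros(1, 8, 8),
        }


def shapes_by_key(dataset):
    """Map each dict key to {shape: [sample indices]}, bypassing the DataLoader."""
    found = defaultdict(lambda: defaultdict(list))
    for idx in range(len(dataset)):
        sample = dataset[idx]
        for key, value in sample.items():
            # torch.as_tensor mirrors what default_collate does to array-like values
            shape = tuple(torch.as_tensor(value).shape)
            found[key][shape].append(idx)
    return found


report = shapes_by_key(ToyDataset())
for key, shapes in report.items():
    # More than one distinct shape under a key means torch.stack would fail on it
    flag = "MISMATCH" if len(shapes) > 1 else "ok"
    print(key, flag, dict(shapes))
```

Running the same function on the real dataset should show which key has more than one shape and which sample indices are the odd ones out. Setting num_workers=0 on the DataLoader while debugging also keeps loading in the main process, so the traceback is shorter and print statements inside __getitem__ or the transforms show up directly.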