I overrode the original transforms.Compose
(vision/transforms.py at master · pytorch/vision · GitHub for preprocessing multiple targets, I found the type of img1 seemed turned to be class ‘torch.Tensor’ rather than normal class ‘PIL.Image.Image’ like img2. This seemed weird. Does anyone know the reason for this? Thanks!
Here’ my code,
class Compose(object):
def __init__(self, transforms):
self.transforms = transforms
def __call__(self, img1, img2, mask1,mask2):
print(type(img1))
print(type(img2))
print(type(mask1))
print(type(mask2))
print("img1.size",img1.size)
print("mask1.size",mask1.size)
print("mask2.size",mask2.size)
assert img1.size() == mask1.size
assert img1.size() == mask2.size
for t in self.transforms:
img1, img2,mask1,mask2 = t(img1,img2,mask1,mask2)
return img1, img2,mask1,mask2
Then I got ,
type(img1) <class 'torch.Tensor'>
type(img2) <class 'PIL.Image.Image'>
type(mask1) <class 'PIL.Image.Image'>
type(mask2) <class 'PIL.Image.Image'>
img1.size() <built-in method size of Tensor object at 0x2b18bf80eb80>
mask1.size() (876, 376)
mask2.size() (876, 376)
...
Traceback (most recent call last):
File "train.py", line 210, in <module>
main()
File "train.py", line 110, in main
train(net, optimizer)
File "train.py", line 117, in train
img1.size() <built-in method size of Tensor object at 0x2b18bf7f3d00>
mask1.size() (550, 488)
mask2.size() (550, 488)
for i, data in enumerate(train_loader):
File "/home/public/software/anaconda3/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 363, in __next__
data = self._next_data()
File "/home/public/software/anaconda3/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 989, in _next_data
return self._process_data(data)
File "/home/public/software/anaconda3/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1014, in _process_data
data.reraise()
File "/home/public/software/anaconda3/lib/python3.8/site-packages/torch/_utils.py", line 395, in reraise
raise self.exc_type(msg)
AssertionError: Caught AssertionError in DataLoader worker process 0.
Original Traceback (most recent call last):
File "/home/public/software/anaconda3/lib/python3.8/site-packages/torch/utils/data/_utils/worker.py", line 185, in _worker_loop
data = fetcher.fetch(index)
File "/home/public/software/anaconda3/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 44, in fetch
data = [self.dataset[idx] for idx in possibly_batched_index]
File "/home/public/software/anaconda3/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 44, in <listcomp>
data = [self.dataset[idx] for idx in possibly_batched_index]
File "/home/public/MyModel/datasets.py", line 75, in __getitem__
img1, img2, mask1, mask2 = self.joint_transform(img1, img2, mask1, mask2)
File "/home/public/MyModel/joint_transforms.py", line 16, in __call__
assert img1.size() == mask1.size
AssertionError