My goal is to fine-tune a pre-trained object segmentation model on my own dataset, which has its own classes. So, I created my own dataset in the COCO format.
Below, you can see that I am trying to create a Dataset using the CocoDetection class.
import torchvision.datasets as dset
import torchvision.transforms as T  # assumed: T refers to torchvision.transforms
def get_transform():
    custom_transforms = []
    custom_transforms.append(T.ToTensor())
    return T.Compose(custom_transforms)
coco_test = dset.CocoDetection(root=root_dir,
                               annFile=test_json,
                               transforms=get_transform())
# Test if dataset is working
for i in range(len(coco_test)):
    sample = coco_test[i]
    print(sample[0])
    if i == 3:
        break
At first, I tried it without applying any transformations, and it worked. However, after I added a transform, it stopped working. The error says that __call__() takes 2 positional arguments but 3 were given. But I inspected the source code, and I am giving it 2 arguments. I don't know what the problem is or what to do next. Please help. Here's the full error. By the way, I am using Google Colab.
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-13-bc0dddb4b6c5> in <module>()
1 for i in range(len(coco_train)):
----> 2 sample = coco_train[i]
3 print(sample[0])
4
5 if i == 3:
/usr/local/lib/python3.6/dist-packages/torchvision/datasets/coco.py in __getitem__(self, index)
116 img = Image.open(os.path.join(self.root, path)).convert('RGB')
117 if self.transforms is not None:
--> 118 img, target = self.transforms(img, target)
119
120 return img, target
TypeError: __call__() takes 2 positional arguments but 3 were given
From the docs for CocoDetection:
- transform (callable, optional) – A function/transform that takes in a PIL image and returns a transformed version. E.g., transforms.ToTensor
- target_transform (callable, optional) – A function/transform that takes in the target and transforms it.
- transforms (callable, optional) – A function/transform that takes the input sample and its target as entry and returns a transformed version.
You could either pass the transform and target_transform separately or use a transforms function, which accepts the image and target.
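A minimal sketch of such a joint callable, assuming you only want to transform the image and pass the COCO target (the list of annotation dicts) through unchanged. `ImageOnlyToJoint` is a hypothetical helper name, not a torchvision class; it is similar in spirit to the `Compose` in torchvision's detection reference scripts (`references/detection/transforms.py`), whose `__call__` takes both image and target.

```python
from typing import Any, Callable, Iterable, Tuple


class ImageOnlyToJoint:
    """Hypothetical helper: wraps image-only transforms so they fit the
    `transforms=` argument, which is called as transforms(img, target)."""

    def __init__(self, image_transforms: Iterable[Callable[[Any], Any]]):
        self.image_transforms = list(image_transforms)

    def __call__(self, img: Any, target: Any) -> Tuple[Any, Any]:
        # Apply each image transform in order; the target is returned as-is.
        for t in self.image_transforms:
            img = t(img)
        return img, target
```

In the question's code this would be passed as `transforms=ImageOnlyToJoint([T.ToTensor()])`, with `T` being `torchvision.transforms`. Any plain callable works too, e.g. `ImageOnlyToJoint([str.upper])("img", [])` returns `("IMG", [])`.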
I don't get it. So, is there something wrong with what I did? I read other code, and they also created a list of transforms. And why would the number of positional arguments I passed be 3, when it is clearly two? Do you think maybe it just hasn't been updated for the latest version?
In other code snippets, the transformation was most likely applied only to the data, so a single transform argument was passed.
For the CocoDetection dataset, you are transforming the data and target, so you should either pass the transformations separately or provide a transformation function that accepts the data and the target as arguments.
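The self argument is also counted in Python. If T is torchvision.transforms, the line self.transforms(img, target) in __getitem__ ends up calling Compose.__call__(self, img, target), i.e. three positional arguments, while Compose.__call__ only accepts (self, img).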
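To see where the 3 comes from, here is a small pure-Python reproduction. `ImageOnlyCompose` and `call_like_coco_detection` are hypothetical stand-ins: the first defines `__call__(self, img)` like an image-only Compose, the second mimics how the dataset calls its joint transform with both image and target.

```python
class ImageOnlyCompose:
    """Hypothetical stand-in for an image-only Compose: __call__(self, img)."""

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, img):
        for t in self.transforms:
            img = t(img)
        return img


def call_like_coco_detection(transforms, img, target):
    """Mimics the line `img, target = self.transforms(img, target)`."""
    return transforms(img, target)


compose = ImageOnlyCompose([str.upper])
print(compose("img"))  # self + img = 2 positional arguments: works, prints IMG

error_message = None
try:
    # self + img + target = 3 positional arguments: TypeError
    call_like_coco_detection(compose, "img", {"boxes": []})
except TypeError as exc:
    error_message = str(exc)
# The message ends with: __call__() takes 2 positional arguments but 3 were given
print(error_message)
```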
No, I don’t think so.
EDIT: Note that these arguments are optional, so you can also pass only your transformation as transform=get_transform() to skip the target transformation.
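Roughly how the options relate: torchvision's VisionDataset wraps transform/target_transform into a single joint callable (StandardTransform) and raises an error if you pass both transforms and transform/target_transform. The sketch below is a simplified re-creation for illustration, not torchvision's actual code; `SimpleStandardTransform` and `build_joint_transform` are hypothetical names.

```python
from typing import Any, Callable, Optional, Tuple


class SimpleStandardTransform:
    """Simplified re-creation (hypothetical name) of a wrapper that turns
    separate image/target transforms into one (img, target) callable."""

    def __init__(self, transform: Optional[Callable] = None,
                 target_transform: Optional[Callable] = None):
        self.transform = transform
        self.target_transform = target_transform

    def __call__(self, img: Any, target: Any) -> Tuple[Any, Any]:
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return img, target


def build_joint_transform(transforms: Optional[Callable] = None,
                          transform: Optional[Callable] = None,
                          target_transform: Optional[Callable] = None):
    """Hypothetical helper mirroring how the three arguments are combined."""
    has_separate = transform is not None or target_transform is not None
    if transforms is not None and has_separate:
        raise ValueError("Pass either transforms or transform/target_transform")
    if has_separate:
        return SimpleStandardTransform(transform, target_transform)
    return transforms  # may be None: no transformation at all
```

With transform= only, the image-only callable is wrapped, so it is called with one argument: `build_joint_transform(transform=str.upper)("img", [1])` returns `("IMG", [1])`.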
Ohh right, the self is also counted as an argument. Thanks for the reminder!
Anyway, I just solved my problem. It appears that I was using the wrong CocoDetection Dataset class. I originally used:
import torchvision.datasets as dset
# and then
coco_test = dset.CocoDetection(root=root_dir,
                               annFile=test_json,
                               transforms=get_transform())
But what I should've used is:
from torchvision.datasets.coco import CocoDetection
# and then
coco_test = CocoDetection(root=root_dir,
                          annFile=test_json,
                          transforms=get_transform())
And it worked!
So, it appears that there are two CocoDetection classes: one in torchvision.datasets and another in torchvision.datasets.coco. Should I inform them about this duplicate code? And if so, how should I do that?
Both names point to the same implementation, so I'm wondering how that solved the issue, but it's good to hear it's working.
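One way to check this yourself: a package-level name that is re-exported from a submodule is the very same object. The sketch below uses `importlib`; `is_same_reexport` is a hypothetical helper, demonstrated with the standard library's `json` package, which re-exports `JSONDecoder` from `json.decoder`.

```python
import importlib


def is_same_reexport(package: str, submodule: str, name: str) -> bool:
    """Return True if `package.name` and `submodule.name` are the same object."""
    pkg = importlib.import_module(package)
    sub = importlib.import_module(submodule)
    return getattr(pkg, name) is getattr(sub, name)


print(is_same_reexport("json", "json.decoder", "JSONDecoder"))  # True
```

In Colab, `is_same_reexport("torchvision.datasets", "torchvision.datasets.coco", "CocoDetection")` should likewise return True, because `torchvision.datasets` imports `CocoDetection` from its `coco` submodule.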
I think it may have something to do with the first implementation inheriting from the class VisionDataset, while the second one inherits from data.Dataset.