I'm trying to load a dataset for super-resolution, and I have set up two functions that use Compose to crop and resize the images.
The function for the input images works correctly, and those images come out as expected. The transform function for the target images is essentially identical; it just omits the Resize step.
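from torchvision.transforms import CenterCrop, Compose, Resize, ToTensor


def input_trans(c_size, sF):
    # Low-resolution input: crop, then downscale by the scale factor sF
    return Compose([
        CenterCrop(c_size),
        Resize(c_size // sF),
        ToTensor(),
    ])


def goal_trans(c_size):
    # High-resolution target: the same crop, with no resize
    return Compose([
        CenterCrop(c_size),
        ToTensor(),
    ])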
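One possible cause worth ruling out (an assumption, since the code as posted does show ToTensor()): Compose simply calls each entry in its list on the image in turn. If a transform class is listed without its parentheses, building the pipeline succeeds and the error only appears once an image reaches that step. The pure-Python sketch below reproduces the mechanism; MiniCompose, FakeToTensor and center_crop are hypothetical stand-ins, not torchvision code, and the error wording assumes Python 3.7 or later.

```python
class MiniCompose:
    """Hypothetical stand-in for a Compose-style pipeline: calls each step in order."""

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, img):
        for t in self.transforms:
            img = t(img)  # each entry must be callable with a single image argument
        return img


class FakeToTensor:
    """Hypothetical transform stand-in: defines __call__ but no __init__."""

    def __call__(self, pic):
        # Pretend "conversion": wrap the input so we can see the step ran.
        return ("tensor", pic)


def center_crop(img):
    # Hypothetical crop step; it just tags the input.
    return ("cropped", img)


ok_pipeline = MiniCompose([center_crop, FakeToTensor()])  # instance: correct
bad_pipeline = MiniCompose([center_crop, FakeToTensor])   # class: missing ()

print(ok_pipeline("img"))  # ('tensor', ('cropped', 'img'))

try:
    bad_pipeline("img")  # fails only when the image reaches the second step
except TypeError as err:
    print(err)  # FakeToTensor() takes no arguments
```

The point of the sketch is that the class itself gets called with the image as a constructor argument, which is exactly what "takes no arguments" is complaining about.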
These functions are used in my dataset class when the images are loaded. I originally had goal = input.copy(), but I changed it so that input and goal each load the image separately (I was testing whether .copy() was the issue).
def __getitem__(self, idx):
    input = Image.open(self.image_filenames[idx]).convert('RGB')
    goal = Image.open(self.image_filenames[idx]).convert('RGB')
    if self.input_transform:
        input = self.input_transform(input)
    if self.goal_transform:
        print(goal)
        print(goal.size)
        goal = self.goal_transform(goal)
    return input, goal
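To catch that kind of slip when the dataset is constructed rather than on the first batch, a small check can flag any pipeline entry that is still a class instead of an instance. This is a hypothetical helper (find_uninstantiated, FakeResize and FakeToTensor are illustrative stand-ins), written in plain Python so it runs without torchvision:

```python
import inspect


def find_uninstantiated(steps):
    """Return (index, name) for every entry that is a class rather than an instance.

    Hypothetical debugging helper: a transform class listed without () only
    fails later, when the pipeline is finally called on an image.
    """
    return [(i, s.__name__) for i, s in enumerate(steps) if inspect.isclass(s)]


class FakeResize:
    """Hypothetical stand-in for a transform that takes a size argument."""

    def __init__(self, size):
        self.size = size

    def __call__(self, img):
        return img


class FakeToTensor:
    """Hypothetical stand-in for a transform with no constructor arguments."""

    def __call__(self, pic):
        return pic


print(find_uninstantiated([FakeResize(64), FakeToTensor()]))  # []
print(find_uninstantiated([FakeResize(64), FakeToTensor]))    # [(1, 'FakeToTensor')]
```

torchvision's Compose keeps its list in the .transforms attribute, so the same check could be run on self.goal_transform.transforms inside the dataset's __init__.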
The error I receive is the following:
Traceback (most recent call last):
  File "main.py", line 128, in <module>
    main() # execute this only when run directly, not when imported!
  File "main.py", line 55, in main
    train_model(epoch)
  File "main.py", line 40, in train_model
    for data_item, batch in enumerate(training_data_loader):
  File "C:\Users\[NAME]\anaconda3\envs\pytorch\lib\site-packages\torch\utils\data\dataloader.py", line 345, in __next__
    data = self._next_data()
  File "C:\Users\[NAME]\anaconda3\envs\pytorch\lib\site-packages\torch\utils\data\dataloader.py", line 385, in _next_data
    data = self._dataset_fetcher.fetch(index) # may raise StopIteration
  File "C:\Users\[NAME]\anaconda3\envs\pytorch\lib\site-packages\torch\utils\data\_utils\fetch.py", line 44, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "C:\Users\[NAME]\anaconda3\envs\pytorch\lib\site-packages\torch\utils\data\_utils\fetch.py", line 44, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "main.py", line 118, in __getitem__
    goal = self.goal_transform(goal)
  File "C:\Users\[NAME]\anaconda3\envs\pytorch\lib\site-packages\torchvision\transforms\transforms.py", line 70, in __call__
    img = t(img)
TypeError: ToTensor() takes no arguments
This confuses me because it doesn't seem to have a problem with the first transformation (I've checked, and it does produce output before crashing).
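Because the traceback only points at the generic img = t(img) line inside Compose, it can help to apply the steps one at a time and report which index fails. A minimal sketch, assuming a plain list of callables; run_steps_verbose, crop and FakeToTensor are hypothetical names used only for illustration:

```python
def run_steps_verbose(steps, img):
    """Apply each step in order and re-raise with the index of the failing step.

    Hypothetical debugging helper that mirrors the loop inside a Compose-style
    pipeline, so an error can be tied to one specific step.
    """
    for i, step in enumerate(steps):
        try:
            img = step(img)
        except Exception as err:
            raise RuntimeError(f"step {i} ({step!r}) failed: {err}") from err
    return img


class FakeToTensor:
    """Hypothetical transform stand-in: defines __call__ but no __init__."""

    def __call__(self, pic):
        return pic


def crop(img):
    # Hypothetical crop step that passes the input through unchanged.
    return img


print(run_steps_verbose([crop, FakeToTensor()], "img"))  # img

try:
    run_steps_verbose([crop, FakeToTensor], "img")  # class listed without ()
except RuntimeError as err:
    print(err)  # names step 1 and includes the original TypeError message
```

With torchvision, the same helper could be given self.goal_transform.transforms and the opened PIL image to see exactly which entry of the goal pipeline raises.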
I would appreciate any help you guys can give,
Thanks