So I’ve tried to follow this tutorial and I keep getting an error. My csv file has a few changes in the ordering of the columns:
ID
File name
description
target
root_dir (location)
This is my data class I tried to do following the tutorial:
class roofDataset(Dataset):
‘’’
roof data class
'''
def __init__(self, csv_file, transform = None):
'''
Args:
csv_file (string): Path to the csv file with annotations.
root_dir (string): Directory with all the images.
transform (callable, optional): Optional transform to be applied
on a sample.
'''
self.roof_frame = pd.read_csv(csv_file)
self.root_dir = self.roof_frame
self.transform = transform
def __len__(self):
return len(self.roof_frame)
def __getitem__(self, idx):
if torch.is_tensor(idx):
idx = idx.tolist()
img_name = os.path.join(self.root_dir.iloc[idx,5], self.roof_frame.iloc[idx, 1])
image = io.imread(img_name)
roof = self.roof_frame.iloc[idx, 1:]
roof = np.array([roof])
sample = {'image': image, 'roof': roof}
if self.transform:
sample = self.transform(sample)
return sample
Then I ran the following to test:
roof_dataset = roofDataset(csv_file=‘D:\CIS inspection images 0318\self_build\train\train_roof_images.csv’,
transform = train_transform)
fig = plt.figure()
for i in range(len(roof_dataset)):
sample = roof_dataset[i]
print(i, sample['image'].shape, sample['roof'].shape)
ax = plt.subplot(1, 4, i + 1)
plt.tight_layout()
ax.set_title('Sample #{}'.format(i))
ax.axis('off')
show_landmarks(**sample)
if i == 3:
plt.show()
break
And now I get this error:
TypeError Traceback (most recent call last)
in
5
6 for i in range(len(roof_dataset)):
----> 7 sample = roof_dataset[i]
8
9 print(i, sample[‘image’].shape, sample[‘roof’].shape)
in getitem(self, idx)
23 idx = idx.tolist()
24
—> 25 img_name = os.path.join(self.root_dir.iloc[idx,5], self.roof_frame.iloc[idx, 1])
26 image = io.imread(img_name)
27 roof = self.roof_frame.iloc[idx, 1:]
C:\ProgramData\Anaconda3\lib\ntpath.py in join(path, *paths)
113 return result_drive + result_path
114 except (TypeError, AttributeError, BytesWarning):
→ 115 genericpath._check_arg_types(‘join’, path, *paths)
116 raise
117
C:\ProgramData\Anaconda3\lib\genericpath.py in _check_arg_types(funcname, *args)
147 else:
148 raise TypeError(‘%s() argument must be str or bytes, not %r’ %
→ 149 (funcname, s.class.name)) from None
150 if hasstr and hasbytes:
151 raise TypeError(“Can’t mix strings and bytes in path components”) from None
TypeError: join() argument must be str or bytes, not ‘int64’