I got this error while trying to test CNN model.
I already checked type about this error point’variable.
Here is error point
10. imgs = imgs.to(device) #imgs type <class ‘torch.Tensor’>
—> 11 labels = labels.to(device) #labels type <class ‘torch.Tensor’>
AttributeError: ‘tuple’ object has no attribute ‘to’
Both are Tensor type, there is not tuple type.
I try to make image classification to use multiple folders
ex)
folderA > folderB> folderC
CatDataSet -Train - White
- Black
- Brown
-Test - White
- Black
- Brown
-Val - White
- Black
- Brown
Here is my test function
def test(model, data_loader, device):
print('Start test..')
model.eval()
with torch.no_grad():
correct = 0
total = 0
for i, (imgs, labels) in enumerate(data_loader):
imgs = imgs.to(device) #imgs type <class 'torch.Tensor'>
print("first labels",type(label))
labels = labels.to(device)
print("second labels",type(label))
outputs = model(imgs)
print(outputs)
_, argmax = torch.max(outputs, 1)
total += imgs.size(0)
correct += (labels == argmax).sum().item()
print('Test accuracy for {} images: {:.2f}%'.format(total, correct / total * 100))
model.train()
Here is my CNN function
class SimpleCNN(nn.Module):
def __init__(self):
super().__init__()
self.conv = nn.Sequential(
nn.Conv2d(3, 32, 3),
nn.BatchNorm2d(32),
nn.ReLU(),
nn.MaxPool2d(2),
nn.Conv2d(32,64,3),
nn.BatchNorm2d(64),
nn.ReLU(),
nn.MaxPool2d(2),
nn.Conv2d(64,128,3),
nn.BatchNorm2d(128),
nn.ReLU(),
nn.MaxPool2d(2),
nn.Conv2d(128,128,3),
nn.BatchNorm2d(128),
nn.ReLU(),
nn.MaxPool2d(2),
)
self.fc1 = nn.Linear(128*5*5,512)
self.fc2 = nn.Linear(512,3)
self.relu = nn.ReLU()
def forward(self, x):
x = self.conv(x)
x = x.view(x.shape[0], -1)
x = F.relu(self.fc1(x))
x = self.fc2(x)
return x
Here is DataSet function
class CatDataset(Dataset):
def __init__(self, data_dir, mode, transform=None):
self.all_data = sorted(glob.glob(os.path.join(data_dir, mode, '*', '*')))
self.transform = transform
def __getitem__(self, index):
data_path = self.all_data[index]
imgs = Image.open(data_path)
imgs = self.transform(imgs)
label = os.path.basename(data_path)
print("data_path label",type(label)) #label type : str
if label.startswith("Black"):
label=0
elif label.startswith("White"):
label=1
elif label.startswith("Brown"):
label=2
print("label",type(label)) #label type : int
return imgs, label
def __len__(self):
length = len(self.all_data)
return length
Please help me where did I wrote wrong code.