AttributeError Traceback (most recent call last)
in ()
28 dataset = IF(root=data_root, transform=torchvision.transforms.ToTensor())
29 loader = data_utils.DataLoader(dataset, batch_size=5,shuffle=True)
—> 30 train_dataset, test_dataset = train_test_split(dataset, .2)
31 trainloader = data_utils.DataLoader(train_dataset, batch_size=20, shuffle=True)
32 testloader = data_utils.DataLoader(test_dataset, batch_size=20, shuffle=True)
in train_test_split(dataset, test_size)
15 train_dataset = copy.deepcopy(dataset)
16 test_dataset = copy.deepcopy(dataset)
—> 17 total_n = train_dataset.len()
18 rand_perm = permutation(total_n)
19 cutoff = int(test_size * total_n)
AttributeError: ‘ImageFolder’ object has no attribute ‘len’
I am trying to implemen below code for lerning pytorch (loaded file to use in a classifier) where the input is jpg images.
[quote=“farrokhi, post:1, topic:1859, full:true”]
import torch
import torchvision
import numpy as np
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torch.nn.parallel
import torch.backends.cudnn as cudnn
from torch.autograd import Variable
import torch.nn.functional as F
import copy
from numpy.random import permutation
def train_test_split(dataset, test_size):
train_dataset = copy.deepcopy(dataset)
test_dataset = copy.deepcopy(dataset)
total_n = train_dataset.len()
rand_perm = permutation(total_n)
cutoff = int(test_size * total_n)
test_dataset.imgs = [dataset.imgs[rand_perm[i]] for i in range(0, cutoff)]
train_dataset.imgs = [dataset.imgs[rand_perm[i]] for i in range(cutoff, total_n)]
return train_dataset, test_dataset
#great dataset/loader for train and test
from torchvision.datasets import ImageFolder as IF
import torchvision
import torch.utils.data as data_utils
data_root = './Genuine/'
dataset = IF(root=data_root, transform=torchvision.transforms.ToTensor())
loader = data_utils.DataLoader(dataset, batch_size=5,shuffle=True)
train_dataset, test_dataset = train_test_split(dataset, .2)
trainloader = data_utils.DataLoader(train_dataset, batch_size=20, shuffle=True)
testloader = data_utils.DataLoader(test_dataset, batch_size=20, shuffle=True)