ztttkx
1
def default_loader(path):
    """Open the image file at ``path`` and return it converted to mode ``'L'``.

    Args:
        path: Location of the image file, passed straight to ``Image.open``.

    Returns:
        The result of calling ``convert('L')`` on the opened image.
    """
    opened = Image.open(path)
    return opened.convert('L')
class myImageFloder(data.Dataset):
    """Map-style dataset of images listed in a plain-text label file.

    Label file layout, as parsed by ``__init__``:

    * first line: attribute (class) names separated by single spaces;
    * every following line: an image file name relative to ``root``,
      followed by whitespace-separated numeric label values.

    Entries whose image file does not exist under ``root`` are skipped, and
    blank lines are ignored.

    Args:
        root: Directory that contains the image files.
        label: Path of the label file described above.
        train: Stored on the instance as ``self.train``. It does not change
            which samples are served; the samples come solely from ``label``.
        transform: Optional callable applied to the loaded image.
        target_transform: Optional callable applied to the label tuple.
        loader: Callable that takes a file path and returns the image.
    """

    def __init__(self, root, label, train=True, transform=None,
                 target_transform=None, loader=default_loader):
        imgs = []
        class_names = []
        # Read the label data; the context manager guarantees the file is
        # closed even if parsing a line raises.
        with open(label) as fh:
            for line_no, line in enumerate(fh):
                if line_no == 0:
                    # The header line holds the attribute names of the labels.
                    class_names = [n.strip() for n in line.rstrip().split(' ')]
                    continue
                fields = line.split()
                if not fields:
                    # Blank line (e.g. trailing newline at end of file).
                    continue
                fn = fields[0]
                # Only keep samples whose image file is actually present.
                if os.path.isfile(os.path.join(root, fn)):
                    imgs.append((fn, tuple(float(v) for v in fields[1:])))

        self.root = root
        self.imgs = imgs
        self.classes = class_names
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader
        self.train = train

    def __getitem__(self, index):
        """Load and return the ``(image, target)`` pair at ``index``.

        The image is read with ``self.loader`` from ``root`` joined with the
        stored file name. ``target`` is the tuple of floats parsed from the
        label file, unless ``target_transform`` converts it.

        Raises:
            IndexError: If ``index`` is outside the parsed sample list.
        """
        fn, target = self.imgs[index]
        img = self.loader(os.path.join(self.root, fn))
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return img, target

    def __len__(self):
        """Return the number of samples that were found on disk."""
        # Must match the real sample count; a hard-coded size would let a
        # sampler request indices that do not exist.
        return len(self.imgs)

    def getName(self):
        """Return the attribute names read from the label file's first line."""
        return self.classes
# The only transform applied to each loaded image is ToTensor.
mytransform = transforms.Compose([transforms.ToTensor()])

train_dataset = myImageFloder(
    root="/home/zw/ztfd_data/mnist02/train",
    label="/home/zw/ztfd_data/mnist02/train.txt",
    train=True,
    transform=mytransform,
)
test_dataset = myImageFloder(
    root="/home/zw/ztfd_data/mnist02/test",
    label="/home/zw/ztfd_data/mnist02/test.txt",
    train=False,
    transform=mytransform,
)
print(len(train_dataset))
print(len(test_dataset))

# Wrap the datasets built above instead of constructing second copies: this
# avoids parsing each label file twice and keeps the test loader on the
# dataset that was created with train=False.
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(
    test_dataset, batch_size=64, shuffle=True)
print(len(train_loader))
print(len(test_loader))