Hi, I used

ds = torchvision.datasets.CIFAR10(root='./CIFARdata', train=False,
                                  download=True, transform=transform)

to load the CIFAR-10 dataset, and found that training was slow and CUDA utilization peaked at about 50%. I later switched to the custom data loader below, and training got faster, with CUDA utilization around ~70%.
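For reference, the baseline setup looked roughly like this (the transform, batch size, and worker count shown here are just placeholders, not necessarily my exact values):

import torch
import torchvision
import torchvision.transforms as transforms

# Placeholder baseline: torchvision dataset with CPU-side transforms.
transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),  # decoding and conversion happen on the CPU
])
ds = torchvision.datasets.CIFAR10(root='./CIFARdata', train=False,
                                  download=True, transform=transform)
loader = torch.utils.data.DataLoader(ds, batch_size=128, shuffle=True,
                                     num_workers=4, pin_memory=True)

And here is the custom version: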
import os
import pickle

import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.data as data
import torchvision.transforms as transforms

def image_preprocess(img, config):
    # Scale uint8 pixels to [0, 1] and add a batch dim: (3, H, W) -> (1, 3, H, W).
    img = img.type(torch.float) / 255
    img_amp = img.view(1, 3, img.shape[1], img.shape[2])
    img_amp = F.interpolate(img_amp, size=(config.x_num, config.y_num),
                            mode='bilinear', align_corners=True)
    # Zero-pad symmetrically up to (total_x_num, total_y_num);
    # F.pad takes (left, right, top, bottom), i.e. the last dim (y) first.
    img_padx = (config.total_x_num - config.x_num) // 2
    img_pady = (config.total_y_num - config.y_num) // 2
    img_amp = F.pad(img_amp, (img_pady, img_pady, img_padx, img_padx))
    return img_amp.view(3, config.total_x_num, config.total_y_num)
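For a CIFAR-10 image this takes (3, 32, 32) to (3, total_x_num, total_y_num). A quick shape check, with made-up config values:

from types import SimpleNamespace

# Placeholder config, just to illustrate the shapes.
config = SimpleNamespace(x_num=64, y_num=64, total_x_num=80, total_y_num=80)
img = torch.randint(0, 256, (3, 32, 32), dtype=torch.uint8)
print(image_preprocess(img, config).shape)  # torch.Size([3, 80, 80])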
class Cifar10Dataset(data.Dataset):
    def __init__(self, config, is_training=True):
        '''
        Args:
            is_training: whether the network is in the training phase
        '''
        super(Cifar10Dataset, self).__init__()
        self.config = config
        self.image_dir = config.image_data_path
        self.image_transform = config.image_transform
        self.is_training = is_training  # training set or test set
        self.image_data, self.targets = self._load_data()
    def _load_data(self):
        image_data = []
        targets = []
        if self.is_training:
            image_files = ['data_batch_1', 'data_batch_2', 'data_batch_3',
                           'data_batch_4', 'data_batch_5']
        else:
            image_files = ['test_batch']
        for file_name in image_files:
            file_path = os.path.join(self.image_dir, file_name)
            with open(file_path, 'rb') as f:
                entry = pickle.load(f, encoding='latin1')
                image_data.append(entry['data'])
                if 'labels' in entry:
                    targets.extend(entry['labels'])
                else:
                    targets.extend(entry['fine_labels'])
        image_data = np.vstack(image_data).reshape(-1, 3, 32, 32)
        # Keep the entire dataset resident on the GPU.
        return torch.tensor(image_data).cuda(), torch.tensor(targets).cuda()
    def _transformation(self, input_img):
        # Tensor-based augmentations (torchvision >= 0.8 can apply these
        # directly to GPU tensors, no PIL round-trip needed).
        data_transform = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(8)
        ])
        input_img = torch.squeeze(data_transform(torch.unsqueeze(input_img, dim=0)))
        return input_img
    def __getitem__(self, index):
        input_img, target = self.image_data[index], self.targets[index]  # Tensor
        if self.is_training and self.image_transform is not None:
            input_img = self._transformation(input_img)
        prop_img = image_preprocess(input_img, self.config)
        return prop_img, target

    def __len__(self) -> int:
        return len(self.image_data)
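I create the loader roughly like this (image_data_path and the other config fields are placeholders). Note that num_workers has to stay 0, because the samples already live on the GPU and CUDA tensors don't play well with fork-based DataLoader worker processes:

# Placeholder config values for illustration only.
config = SimpleNamespace(image_data_path='./CIFARdata/cifar-10-batches-py',
                         image_transform=True,
                         x_num=64, y_num=64, total_x_num=80, total_y_num=80)
train_set = Cifar10Dataset(config, is_training=True)
# num_workers=0: forked workers cannot re-initialize CUDA.
train_loader = data.DataLoader(train_set, batch_size=128, shuffle=True,
                               num_workers=0)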
Currently, I suspect the difference comes from where the transforms run: torchvision.datasets.CIFAR10 applies its transforms on the CPU, while the custom loader keeps the data on the GPU and does the resizing and augmentation there. Does anyone have any ideas? Thank you!
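One way to sanity-check where the time goes (a rough per-batch timing sketch, not a rigorous benchmark; loader and train_loader are the two loaders from above):

import time

def time_batches(loader, n_batches=50):
    # Average seconds per batch; synchronize so queued GPU work is counted.
    torch.cuda.synchronize()
    start = time.time()
    for i, (x, y) in enumerate(loader):
        if i >= n_batches:
            break
        x = x.cuda(non_blocking=True)  # no-op if already on the GPU
    torch.cuda.synchronize()
    return (time.time() - start) / n_batches

print('torchvision loader:', time_batches(loader))
print('custom loader:', time_batches(train_loader))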