Hi,
I would like to use multiple GPUs in different parts of my code. I am extracting features from several different magnifications of the same image, but with a single GPU this is quite slow. Is there a simple way to speed it up, perhaps by assigning a different GPU device to each input? I'm unsure how to proceed.
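One way to try the "one GPU per magnification" idea is to keep a separate replica of the feature extractor on each device and send each magnification to its own replica. The sketch below assumes two GPUs and falls back to one GPU or the CPU otherwise. It uses a tiny hypothetical `make_extractor()` network in place of `nn.Sequential(*list(resnet50.children())[:-1])` so that it runs without downloading weights. `pick_devices()` and `extract_pair()` are illustrative names, not part of any library.

```python
import copy

import torch
import torch.nn as nn


def make_extractor() -> nn.Module:
    # Hypothetical stand-in for the ResNet-50 trunk (resnet50 minus its last fc layer).
    return nn.Sequential(
        nn.Conv2d(3, 8, kernel_size=3, padding=1),
        nn.ReLU(),
        nn.AdaptiveAvgPool2d(1),
    )


def pick_devices():
    # Use two GPUs if present; otherwise fall back to one GPU or the CPU.
    n = torch.cuda.device_count()
    if n >= 2:
        return torch.device("cuda:0"), torch.device("cuda:1")
    if n == 1:
        return torch.device("cuda:0"), torch.device("cuda:0")
    return torch.device("cpu"), torch.device("cpu")


base = make_extractor().eval()
dev40, dev20 = pick_devices()
# One replica of the same weights per device.
model40 = copy.deepcopy(base).to(dev40)
model20 = copy.deepcopy(base).to(dev20)


@torch.no_grad()
def extract_pair(inputs40x, inputs20x):
    # non_blocking only has an effect when the source tensor is in pinned memory.
    x40 = model40(inputs40x.to(dev40, non_blocking=True))
    # CUDA kernels are launched asynchronously, so this forward pass can start
    # on the second GPU while the first one is still busy.
    x20 = model20(inputs20x.to(dev20, non_blocking=True))
    # Bring both results back to the CPU before concatenating (and writing to HDF5).
    x40, x20 = x40.cpu(), x20.cpu()
    return x40, x20, torch.cat([x40, x20], dim=1)


inputs40x = torch.randn(4, 3, 32, 32)
inputs20x = torch.randn(4, 3, 32, 32)
x40, x20, x_all = extract_pair(inputs40x, inputs20x)
print(x40.shape, x_all.shape)
```

The `.cpu()` calls are where Python actually waits for the GPUs. Whether the split gives a real speed-up depends on the GPUs being the bottleneck rather than data loading.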
Check out my code below:
I have simplified it by including only two magnifications (20x and 40x).
In the for loop, the inputs are moved to the available device and features are extracted from each of them. I then concatenate the two output feature tensors and move the results to the CPU before writing them to the HDF5 datasets with h5py (they have to be on the CPU because h5py works with NumPy arrays).
```python
import h5py
import torch
import torch.nn as nn

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Build the feature extractor once, outside the loop: drop the last fc layer
# and move the model to the same device as the inputs.
modules = list(resnet50.children())[:-1]
resnet = nn.Sequential(*modules).to(device)
resnet.eval()

with h5py.File(path, mode='r+') as hdf5_file:
    array_40 = hdf5_file[f'{phase}_40x_arrays']
    array_20 = hdf5_file[f'{phase}_20x_arrays']
    array_all = hdf5_file[f'{phase}_all_arrays']
    array_labels = hdf5_file[f'{phase}_labels']
    array_batch_idx = hdf5_file[f'{phase}_batch_idx']
    array_paths = hdf5_file[f'{phase}_img_paths']

    # Resume from the batch after the last one that was written.
    batch_idx = int(array_batch_idx[0] + 1)
    print("Batch ID is restarting from {}".format(batch_idx))

    dataloaders_dict = torch.utils.data.DataLoader(
        datasets_dict,
        batch_size=args.batch_size,
        sampler=SequentialSampler2(datasets_dict, batch_idx),
        num_workers=args.num_workers,
        shuffle=False,
    )

    # No gradients are needed for feature extraction.
    with torch.no_grad():
        for inputs40x, inputs20x, paths40x, paths20x, labels in dataloaders_dict:
            print(f'Batch ID: {batch_idx}')
            inputs40x = inputs40x.to(device)
            inputs20x = inputs20x.to(device)
            paths = paths40x

            x40 = resnet(inputs40x)
            x20 = resnet(inputs20x)
            x_all = torch.cat([x40, x20], dim=1)

            # h5py needs NumPy arrays, so copy the results back to the CPU first.
            array_40[batch_idx, ...] = x40.cpu().numpy()
            array_20[batch_idx, ...] = x20.cpu().numpy()
            array_all[batch_idx, ...] = x_all.cpu().numpy()
            array_labels[batch_idx, ...] = labels.numpy()
            array_batch_idx[:, ...] = batch_idx
            array_paths[batch_idx, ...] = paths
            batch_idx += 1
```
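Alternatively, since both magnifications go through the same network, they could be stacked into one batch and `nn.DataParallel` could split that batch across all visible GPUs. This sketch assumes the 20x and 40x tiles have the same shape. `make_extractor()` is the same kind of hypothetical stand-in for the ResNet-50 trunk, and `extract_both()` is an illustrative name. On a machine without multiple GPUs the model simply runs on a single device.

```python
import torch
import torch.nn as nn


def make_extractor() -> nn.Module:
    # Hypothetical stand-in for nn.Sequential(*list(resnet50.children())[:-1]).
    return nn.Sequential(
        nn.Conv2d(3, 8, kernel_size=3, padding=1),
        nn.ReLU(),
        nn.AdaptiveAvgPool2d(1),
    )


device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = make_extractor().eval().to(device)
if torch.cuda.device_count() > 1:
    # Splits each input batch along dim 0 across all visible GPUs and
    # gathers the outputs back on cuda:0.
    model = nn.DataParallel(model)


@torch.no_grad()
def extract_both(inputs40x, inputs20x):
    # Assumes both magnifications share the same C x H x W, so they fit in one batch.
    b = inputs40x.shape[0]
    both = torch.cat([inputs40x, inputs20x], dim=0).to(device)
    feats = model(both)
    x40, x20 = feats[:b], feats[b:]
    x_all = torch.cat([x40, x20], dim=1)
    # h5py datasets expect NumPy arrays, hence .cpu().numpy().
    return x40.cpu().numpy(), x20.cpu().numpy(), x_all.cpu().numpy()


x40, x20, x_all = extract_both(torch.randn(4, 3, 32, 32), torch.randn(4, 3, 32, 32))
print(x40.shape, x_all.shape)
```

PyTorch's documentation recommends `DistributedDataParallel` over `DataParallel` for multi-GPU work. However, it needs one process per GPU, which is more setup for a one-off feature-extraction loop like this.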