Dear all,
I am trying to parallelise my model to train it on 2 GPUs. I decided to use the nn.DataParallel wrapper, but I get the following error during the forward pass:
Traceback (most recent call last):
  File "main.py", line 93, in
    args.func(args)
  File "/home/mdatres/new_quantlab/quantlab/manager/flows/train.py", line 243, in train
    ypr = net(x)
  File "/data2/mdatres/miniconda3/envs/quantlab/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/data2/mdatres/miniconda3/envs/quantlab/lib/python3.8/site-packages/torch/nn/parallel/data_parallel.py", line 168, in forward
    outputs = self.parallel_apply(replicas, inputs, kwargs)
  File "/data2/mdatres/miniconda3/envs/quantlab/lib/python3.8/site-packages/torch/nn/parallel/data_parallel.py", line 178, in parallel_apply
    return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
  File "/data2/mdatres/miniconda3/envs/quantlab/lib/python3.8/site-packages/torch/nn/parallel/parallel_apply.py", line 86, in parallel_apply
    output.reraise()
  File "/data2/mdatres/miniconda3/envs/quantlab/lib/python3.8/site-packages/torch/_utils.py", line 425, in reraise
    raise self.exc_type(msg)
RuntimeError: Caught RuntimeError in replica 1 on device 1.
Original Traceback (most recent call last):
  File "/data2/mdatres/miniconda3/envs/quantlab/lib/python3.8/site-packages/torch/nn/parallel/parallel_apply.py", line 61, in _worker
    output = module(*input, **kwargs)
  File "/data2/mdatres/miniconda3/envs/quantlab/lib/python3.8/site-packages/torch/fx/graph_module.py", line 513, in wrapped_call
    raise e.with_traceback(None)
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:1!
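For context, the message itself just means that some operation inside the replica received tensors living on different GPUs. A minimal, self-contained illustration (not my actual code, the shapes are arbitrary) that raises the same RuntimeError is:

import torch

a = torch.randn(2, 2, device="cuda:0")
b = torch.randn(2, 2, device="cuda:1")
c = a @ b  # RuntimeError: Expected all tensors to be on the same device, ...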
I used the resnet50 model that is already implemented in torchvision, wrapped in the following class:
import torch
import torch.nn as nn
from torchvision.models import resnet50


class ResNet(nn.Module):
    def __init__(self, num_classes=2, pretrained=True, seed=-1):
        super().__init__()
        self.net1 = resnet50(pretrained=pretrained)
        self.net1.fc = nn.Linear(in_features=2048, out_features=num_classes, bias=True)
        self.softmax = nn.Softmax()
        # if not pretrained:
        #     self.net1._initialize_weights(seed=seed)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.net1(x)
        return self.softmax(x)
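For reference, this is roughly how the wrapper can be exercised on a single device (batch size and input shape are just illustrative, and pretrained=False only to avoid a download in the sketch):

import torch

model = ResNet(num_classes=2, pretrained=False).to("cuda:0")
x = torch.randn(4, 3, 224, 224, device="cuda:0")  # dummy ImageNet-sized batch
probs = model(x)  # (4, 2) class probabilities; PyTorch warns about the implicit Softmax dim
print(probs.shape, probs.sum(dim=1))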
Below I attach the code that I use to load the model onto the GPUs and to train it:
# cycle over epochs (one loop for each fold)
for epoch_id in range(start_epoch_id, logbook.n_epochs):

    # === EPOCH: START ===

    # === TRAINING STEP: START ===
    net.train()

    # master-workers synchronisation point: quantization controllers might change the network's quantization parameters stochastically
    if (not platform.is_horovod_run) or platform.is_master:
        for c in qnt_ctrls:
            c.step_pre_training_epoch(epoch_id)
    if platform.is_horovod_run:
        platform.hvd.broadcast_parameters(net.state_dict(), root_rank=platform.master_rank)

    # cycle over batches of training data (one loop for each epoch)
    for batch_id, (x, ygt) in enumerate(train_loader):

        # master-workers synchronisation point: quantization controllers might change the network's quantization parameters stochastically
        # TODO: in multi-process runs, synchronising processes at each step might be too costly
        if (not platform.is_horovod_run) or platform.is_master:
            for c in qnt_ctrls:
                c.step_pre_training_batch()
        if platform.is_horovod_run:
            platform.hvd.broadcast_parameters(net.state_dict(), root_rank=platform.master_rank)

        # event: forward pass is beginning
        train_meter.step(epoch_id, batch_id)
        train_meter.start_observing()
        train_meter.tic()

        # processing (forward pass)
        x = x.to(platform.device)
        ypr = net(x)

        # loss evaluation
        ygt = ygt.to(platform.device)
        loss = loss_fn(ypr, ygt)

        # event: forward pass has ended; backward pass is beginning
        train_meter.update(ygt, ypr, loss)

        # training (backward pass)
        gd.opt.zero_grad()  # clear gradients
        loss.backward()     # gradient computation
        gd.opt.step()       # gradient descent

        # event: backward pass has ended
        train_meter.toc(ygt)
        train_meter.stop_observing()
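Stripped of the QuantLab-specific bookkeeping (meters, quantization controllers, Horovod branches), the inner loop reduces to the standard pattern below; I include it only so the structure is easier to follow (here opt stands for gd.opt above, everything else is as in the full loop):

for x, ygt in train_loader:
    x = x.to(platform.device)    # inputs go to the master device
    ygt = ygt.to(platform.device)
    ypr = net(x)                 # nn.DataParallel scatters the batch across the GPUs
    loss = loss_fn(ypr, ygt)
    opt.zero_grad()
    loss.backward()
    opt.step()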
In what follows, the variable platform.device is "cuda:0".
net = net.to(platform.device)
if platform.is_nndataparallel_run:
    net = nn.DataParallel(net)  # single-node, single-process, multi-GPU run
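A small check along these lines (the report_devices helper is only an illustrative sketch, not part of QuantLab) shows which devices the parameters and buffers of the underlying module live on after this setup:

import torch.nn as nn

def report_devices(module: nn.Module) -> None:
    # collect the devices of all parameters and buffers of the module
    devices = {p.device for p in module.parameters()}
    devices |= {b.device for b in module.buffers()}
    print("parameter/buffer devices:", devices)

# inspect the underlying model, unwrapping nn.DataParallel if needed
report_devices(net.module if isinstance(net, nn.DataParallel) else net)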
Thanks for your help.