multi-GPU model parallelism: device error

Hi,
My model is too large to fit into one GPU, so I split it across two GPUs. Part of the model is as follows:

cur_layer_input = []
for t in range(seq_len):
    d1 = self.down1.cuda(self.device_ids[0])(img_seq_ring[:, t]).cuda(self.device_ids[0])
    d2 = self.down2.cuda(self.device_ids[0])(d1)
    cur_layer_input.append(d2.cuda(self.device_ids[1]))

I used DDP to wrap this model:

from torch.nn.parallel import DistributedDataParallel as DDP
model = DDP(model, device_ids=[device_ids[0]], find_unused_parameters=True)

However, errors occurred:

  File "xxx.py", line 141, in train_model
    loss.backward()
  File "/home/anaconda3/envs/VIBE/lib/python3.7/site-packages/torch/tensor.py", line 195, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/home/anaconda3/envs/VIBE/lib/python3.7/site-packages/torch/autograd/__init__.py", line 99, in backward
    allow_unreachable=True)  # allow_unreachable flag
RuntimeError: grad.device() == bucket_view.device() INTERNAL ASSERT FAILED at /pytorch/torch/csrc/distributed/c10d/reducer.cpp:206, please report a bug to PyTorch. 

I do not know what bucket_view.device() is, so I do not know how to solve this problem.
Could you help me?

Thank you very much.

Hi,

I am not sure how exactly you are trying to use DDP to split your model between devices: DDP's use case is data-parallel training, not splitting a model. Could you take a look at the model parallel tutorial instead? (You can also combine it with DDP once it is set up.) E.g.
https://pytorch.org/tutorials/intermediate/model_parallel_tutorial.html
It requires you to move the inputs between GPUs, not just the model, e.g.

x = self.relu(self.net1(x.to('cuda:0'))) 
return self.net2(x.to('cuda:1'))
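
A minimal sketch of the full pattern from the tutorial (the class name, layer sizes, and dev0/dev1 are just illustrative): each submodule is placed on its own GPU in __init__, and forward() moves the data between them, mirroring the two lines above.

import torch.nn as nn

class ToyMpModel(nn.Module):
    def __init__(self, dev0, dev1):
        super().__init__()
        self.dev0 = dev0
        self.dev1 = dev1
        self.net1 = nn.Linear(10, 10).to(dev0)  # first part on one GPU
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(10, 5).to(dev1)   # second part on the other GPU

    def forward(self, x):
        x = self.relu(self.net1(x.to(self.dev0)))  # move the input to the first GPU
        return self.net2(x.to(self.dev1))          # move the activation to the second GPU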

Yes, I have read the tutorial several times and I have moved inputs to the same device.

I have made some modifications so that the model can fit into one GPU for testing.
The thing is, when I used a single GPU with DistributedDataParallel, it worked well.
But when I moved part of the model to another GPU, the error occurred.

I used a BiRNN model. The code is as follows:

def forward(self, img):
	self.reverse_net = self.reverse_net.cuda(self.device_ids[1])
	img_ring = torch.cat([img, img[:,0].unsqueeze(1)], dim=1).cuda(self.device_ids[1])
	seq_len = img_ring.shape[1]  

	cur_layer_input = []
	cur_layer_input_rev = []

	for t in range(seq_len):
		d2 = F.interpolate(img_ring[:, t], scale_factor=1 / 4, mode="trilinear", align_corners=True)
		cur_layer_input.append(d2.cuda(self.device_ids[0]))
		cur_layer_input_rev.append(d2.cuda(self.device_ids[1]))

	cur_layer_input_rev = cur_layer_input_rev[::-1] 
	y_out_fwd = self.forward_net(cur_layer_input)  # forward RNN
	y_out_rev = self.reverse_net(cur_layer_input_rev)  # backward RNN

	y_out_fwd = torch.stack(y_out_fwd, dim=0)
	y_out_rev = torch.stack(y_out_rev, dim=0)  
	y_out_rev = torch.flip(y_out_rev, dims=[1])

	ycat = torch.cat((y_out_fwd.cuda(self.device_ids[1]), y_out_rev), dim=2)

	disp_list = []
	for t in range(1, seq_len-1):
		disp_list.append(self.outconv3.cuda(self.device_ids[1])(ycat[t])) 

	for t in range(len(disp_list)):
		disp_list[t] = F.interpolate(disp_list[t], scale_factor=4, mode="trilinear", align_corners=True)
	return disp_list

The error information is:

  File "xxx.py", line 141, in train_model
    loss.backward()
  File "/home/anaconda3/envs/VIBE/lib/python3.7/site-packages/torch/tensor.py", line 195, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/home/anaconda3/envs/VIBE/lib/python3.7/site-packages/torch/autograd/__init__.py", line 99, in backward
    allow_unreachable=True)  # allow_unreachable flag
RuntimeError: grad.device() == bucket_view.device() INTERNAL ASSERT FAILED at /pytorch/torch/csrc/distributed/c10d/reducer.cpp:206, please report a bug to PyTorch. 

So I want to know what bucket_view is and what may cause the above error.

It is weird: I can train using one GPU, but when I move part of the model to another GPU, errors occur. The only thing that changed is the device.

cc @Yanli_Zhao who worked on bucket code. Yanli, could you help here?

@cs123951 for a multi-device module, you should not pass device_ids when wrapping your module with DDP.

Try this:

DDP(model, find_unused_parameters=True)

or, if it is a single-device module, just call torch.cuda.set_device(rank) before wrapping with DDP and pass device_ids=[rank].
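
For example (a sketch; mp_model, model, and rank stand for your own modules and process rank):

# multi-device module: its parameters already live on several GPUs (placed in __init__),
# so do not pass device_ids and let DDP infer the devices from the parameters
ddp_model = DDP(mp_model, find_unused_parameters=True)

# single-device module: pin this process to one GPU, then pass that id
torch.cuda.set_device(rank)
ddp_model = DDP(model.cuda(rank), device_ids=[rank], find_unused_parameters=True)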

Thank you for your kind advice. I tried it, but the problem still occurred.

I created two files in my GitHub repo (https://github.com/cs123951/temp_public_files) to explain my problem.

I constructed a bidirectional RNN model using model parallelism like this.

The difference is that the RNN is bidirectional, so I made some modifications.

I assigned the first several layers to GPU 0 and the last layer to GPU 1.

The forward pass is normal, but when it comes to loss.backward(), this error happens:

RuntimeError: grad.device() == bucket_view.device() INTERNAL ASSERT FAILED at /pytorch/torch/csrc/distributed/c10d/reducer.cpp:206, please report a bug to PyTorch. 

You can reproduce my results using simple_bug_file2.py in my GitHub repo with the following command:
python -m torch.distributed.launch --nproc_per_node=2 --nnodes=1 --node_rank=0 --master_addr=xxx.xxx.xx.xx --master_port=31234 simple_bug_file2.py

Don’t forget to change the master address in the command and the visible GPUs in the file.

My environment is

PyTorch version: 1.4.0
Is debug build: False
CUDA used to build PyTorch: 10.1
ROCM used to build PyTorch: N/A
GPU: GeForce GTX 1080 Ti

OS: Ubuntu 18.04.3 LTS (x86_64)
GCC version: (Ubuntu 7.5.0-3ubuntu1~18.04) 7.5.0
Clang version: Could not collect
CMake version: version 3.3.2

Python version: 3.7 (64-bit runtime)
Nvidia driver version: 455.28
HIP runtime version: N/A
MIOpen runtime version: N/A

pip 21.0.1

By the way, when I changed down_device = self.device_ids[1] to down_device = self.device_ids[0], it performs well.

I guess bucket_view may be related to the unrolling of the RNN, but I am not familiar with it. Could you give me some more advice?

Thank you very much.

bucket_view is a copy of a parameter's grad. bucket_view.device = param.device is set before the training loop starts.

bucket_view.device != param.grad.device means the grad of this param changed device during the training loop.

Could you please confirm whether the param/grad device changed during the training loop?
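
For example, something like this around the training loop would show it (a rough sketch; model here is your underlying module, initial_devices is recorded once before training starts):

# record where every parameter lives when training starts
initial_devices = {name: p.device for name, p in model.named_parameters()}

# ... inside the training loop, after loss.backward() ...
for name, p in model.named_parameters():
    if p.device != initial_devices[name]:
        print(name, "param moved:", initial_devices[name], "->", p.device)
    if p.grad is not None and p.grad.device != p.device:
        print(name, "grad on", p.grad.device, "but param on", p.device)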

Hi, thank you for your explanation.

I have checked the logic of the model and I am sure it is fine. Since the forward pass works well, I cannot figure out why there is something wrong with the backward pass.

The error occurs in reducer.cpp, and I have no idea how to debug a C++ file while running a Python program.

Could you give me some advice on how to debug this and check the param/grad devices during loss.backward() in the training loop?

Thank you very much!

Hello, I have figured out the mistake I made!
I made the mistake in the __init__() of the model, like this:

mp_model = ToyMpModel(devices_list).cuda(devices_list[0])
ddp_mp_model = DDP(mp_model, device_ids=[0], find_unused_parameters=True)

I used a multi-GPU model, but when I created the model, I moved it to only one GPU.
Thus self.is_multi_device_module = len({p.device for p in module.parameters()}) > 1 is False in distributed.py.
I then changed the device in the forward function, which triggered the error.
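
A quick way to see this before wrapping is to evaluate the same expression that distributed.py uses:

param_devices = {p.device for p in mp_model.parameters()}
print(param_devices)           # only one device here, even though forward() later uses another GPU
print(len(param_devices) > 1)  # False, so DDP treats the module as single-device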

The solution is:
I assigned the GPUs in the __init__() function of the model, and then changed the wrapping call to

mp_model = ToyMpModel(devices_list)
ddp_mp_model = DDP(mp_model, device_ids=[], output_device=[], find_unused_parameters=True)
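
For reference, this is roughly what I mean by defining the GPUs in __init__() (a sketch; ForwardNet, ReverseNet, and OutConv stand in for my real submodules):

class ToyMpModel(nn.Module):
    def __init__(self, devices_list):
        super().__init__()
        self.device_ids = devices_list
        # place each submodule on its GPU once here, instead of calling .cuda() inside forward()
        self.forward_net = ForwardNet().cuda(devices_list[0])
        self.reverse_net = ReverseNet().cuda(devices_list[1])
        self.outconv3 = OutConv().cuda(devices_list[1])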

I hope my mistakes can serve as a warning to latecomers.
I am sorry for the confusion.

This problem can be closed.