RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:1 and cuda:0! (when checking argument for argument index in method wrapper__index_select) while using Dataparallel class

I have a Bert encoder model as:

class ClassifierBert(nn.Module):
    def __init__(self, tgt_size, con_bert):
        super(ClassifierBert, self).__init__()
        self.bert_num_tokens = con_bert.vocab_size
        self.bert_hidden_layers = con_bert.hidden_size
        self.b_encoder = custom_bert(con_bert, pretrained=True)
        self.classifier = nn.Linear(self.bert_hidden_layers*MAX_LENGTH, tgt_size)
    

    def forward(self, input):
        features = self.b_encoder(**input).last_hidden_state
        features = torch.flatten(features, 1)
        outputs = self.classifier(features)
        return outputs

and its my device:

DEVICE = torch.device('cuda:0')

and training module:

con = BertConfig(vocab_size=BERT_NUM_TOKEN, num_hidden_layers=NUM_BERT_LAYERS)

sentence_transformer = ClassifierBert(TGT_VOCAB_SIZE, con)
sentence_transformer = nn.DataParallel(sentence_transformer)
loss_fn = nn.BCEWithLogitsLoss(weight=weights)

optimizer = torch.optim.Adam(sentence_transformer.parameters(), lr=LR, 
    betas=(0.9, 0.98), eps=1e-9)

    
    
train()

def train(num_epochs=NUM_EPOCHS):
    BEST_PERCISION = 0.00
    BEST_WTS = copy.deepcopy(sentence_transformer.state_dict())

    for epoch in range(1, num_epochs+1):
        start_time = timer()
        train_loss, train_ap, train_lrap = train_val_epoch(sentence_transformer, 
            optimizer, loss_fn, mode='train')
        val_loss, val_ap, val_lrap = train_val_epoch(sentence_transformer, 
            optimizer, loss_fn, mode='val')
        if val_lrap > BEST_PERCISION:
            BEST_PERCISION = val_lrap
            BEST_WTS = copy.deepcopy(sentence_transformer.state_dict())
        val_ap *= 100
        train_ap *= 100
        val_lrap *= 100
        train_lrap *= 100
        end_time = timer()
        time_passed = int(end_time - start_time)
        print("******************************************************************")
        print(f"Epoch: {epoch}")
        print((f"Train_Loss: {train_loss:.4f}, Val_Loss: {val_loss:.4f},\
            Train_AP: {train_ap:.2f}, Train_LRAP: {train_lrap:.2f},\
            Val_AP: {val_ap:.2f}, Val_LRAP: {val_lrap:.2f}, \
            "f"Epoch_Time = {(time_passed//60) :d} min and {(time_passed%60) :d} sec"))
    

    torch.save(BEST_WTS, "%s/weights/sentence_transformer_%.2f.pt" % (PATH, BEST_PERCISION*100))

def train_val_epoch(model, optimizer, loss_function, mode):
    if mode == 'train':
        model.train()
        data_loader = DataLoader(train_iterator, batch_size=BATCH_SIZE, 
            collate_fn=collate_fn, shuffle=True)
    if mode == 'val':
        model.eval()
        data_loader = DataLoader(val_iterator, batch_size=BATCH_SIZE, 
            collate_fn=collate_fn)
    

    losses = 0.00
    ap, lrap = 0.00, 0.00

    model = model.to(DEVICE)
    
    for src, tgt in data_loader:
        src = src.to(DEVICE)
        outputs = model(src)
        out_shape = outputs.shape[1]
        
        tgt = build_target(tgt, out_shape).to(0)

        optimizer.zero_grad()

        loss = loss_function(outputs, tgt)

        if mode == 'train':
            loss.mean().backward()
            optimizer.step()

        losses += loss.mean().item()

        ap_batch = average_precision_score(tgt.cpu().detach().numpy().astype(int),\
             outputs.cpu().detach().numpy(), average='samples')
        lrap_batch = label_ranking_average_precision_score(tgt.cpu().detach().numpy().astype(int),\
             outputs.cpu().detach().numpy())
        ap += ap_batch
        lrap += lrap_batch
    
    return losses / len(data_loader), ap / len(data_loader), lrap / len(data_loader)

running this I get the error:

Traceback (most recent call last):
  File "transformer_classifier.py", line 258, in <module>
    train()
  File "transformer_classifier.py", line 162, in train
    optimizer, loss_fn, mode='train')
  File "transformer_classifier.py", line 91, in train_val_epoch
    outputs = model(src)
  File "/home/ubuntu/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/ubuntu/.local/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py", line 168, in forward
    outputs = self.parallel_apply(replicas, inputs, kwargs)
  File "/home/ubuntu/.local/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py", line 178, in parallel_apply
    return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
  File "/home/ubuntu/.local/lib/python3.6/site-packages/torch/nn/parallel/parallel_apply.py", line 86, in parallel_apply
    output.reraise()
  File "/home/ubuntu/.local/lib/python3.6/site-packages/torch/_utils.py", line 434, in reraise
    raise exception
RuntimeError: Caught RuntimeError in replica 1 on device 1.
Original Traceback (most recent call last):
  File "/home/ubuntu/.local/lib/python3.6/site-packages/torch/nn/parallel/parallel_apply.py", line 61, in _worker
    output = module(*input, **kwargs)
  File "/home/ubuntu/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "transformer_classifier.py", line 55, in forward
    features = self.b_encoder(**input).last_hidden_state
  File "/home/ubuntu/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/ubuntu/.local/lib/python3.6/site-packages/transformers/models/bert/modeling_bert.py", line 994, in forward
    past_key_values_length=past_key_values_length,
  File "/home/ubuntu/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/ubuntu/.local/lib/python3.6/site-packages/transformers/models/bert/modeling_bert.py", line 214, in forward
    inputs_embeds = self.word_embeddings(input_ids)
  File "/home/ubuntu/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/ubuntu/.local/lib/python3.6/site-packages/torch/nn/modules/sparse.py", line 160, in forward
    self.norm_type, self.scale_grad_by_freq, self.sparse)
  File "/home/ubuntu/.local/lib/python3.6/site-packages/torch/nn/functional.py", line 2044, in embedding
    return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:1 and cuda:0! (when checking argument for argument index in method wrapper__index_select)

what am I doing wrong!!!

Looks like embeddings and indices are on two different devices? In your model’s forward function, can you print the deivces of self.b_encoder.word_embeddings and the inputs? I assume one of those two are not moved to the correct device.

hey thanks for you reply, I don’t think that one could print the device of input and embeddings in pytorch, or at least I don’t know, do you know how?

It is easy to check the device where tensors are allocated by xxxx.device

You should check whether custom_bert is defined under nn.Module or not.
model.to(DEVICE) is convenient method to allocate your model to the defined device but
if there is a non-parameter-based tensor such as torch.zeros, you should manually allocate it to your device.

1 Like

could you elaborate more? what do you mean by non-parameter based tensor?

I have changed custom_bert to BertModel from hugging face that are definitely defined under nn.Module, the bug persists to exist

I explained the concept of parameter in weird way…

I meant a tensor which is not inherit nn such as torch.zeros

Anyway, I will take a look for finding an error
Thx

ok yes I thought so too, but it feel like to be some sort of a bug for this class, eventually could solve it by completely abandoning the nn.DataParallel and using nn.parallel.DistributedDataParallel instaed. It was little bit painful and not as easy as DataParallel, but happy that did it: nn.parallel.DistributedDataParallel seems to be more concise.

2 Likes