Having a problem with multi-GPU in a PyTorch Transformer

I am currently trying to build a translation model with a Transformer model in PyTorch. Since I have 2 GPUs (2 x 2080 Ti) available for training, I want to train the model on multiple GPUs. Currently, the GPUs are assigned indices 0 and 1. The way I use multi-GPU is to wrap the model object in nn.DataParallel.

I declared the encoder, decoder, and model objects as follows:

enc = Encoder(INPUT_DIM, HIDDEN_DIM, ENC_LAYERS, ENC_HEADS, ENC_PF_DIM, ENC_DROPOUT, device)
dec = Decoder(OUTPUT_DIM, HIDDEN_DIM, DEC_LAYERS, DEC_HEADS, DEC_PF_DIM, DEC_DROPOUT, device)
model = nn.DataParallel(Transformer(enc, dec, SRC_PAD_IDX, TRG_PAD_IDX, device).to(device))

Since I have used this approach with various PyTorch models in the past, I thought it would work without any problems. However, when I actually run training this way, I get the following error:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In [22], line 87
     84 for epoch in range(N_EPOCHS):
     85     start_time = time.time() # record start time
---> 87     train_loss = train(model, train_iterator, optimizer, criterion, CLIP)
     88     valid_loss = evaluate(model, validation_iterator, criterion)
     90     end_time = time.time() # record end time

Cell In [18], line 15, in train(model, iterator, optimizer, criterion, clip)
     11 optimizer.zero_grad()
     13 # exclude the last index (<eos>) of the output tokens
     14 # the decoder input starts from <sos>
---> 15 output, _ = model(src, trg[:,:-1])
     17 # output: [batch size, trg_len - 1, output_dim]
     18 # trg: [batch size, trg_len]
     20 output_dim = output.shape[-1]

File ~/anaconda3/envs/jki_pytorch/lib/python3.9/site-packages/torch/nn/modules/module.py:1130, in Module._call_impl(self, *input, **kwargs)
   1126 # If we don't have any hooks, we want to skip the rest of the logic in
   1127 # this function, and just call forward.
   1128 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1129         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1130     return forward_call(*input, **kwargs)
   1131 # Do not call functions when jit is used
...
    return forward_call(*input, **kwargs)
  File "/tmp/ipykernel_203464/284771533.py", line 31, in forward
    src = self.dropout((self.tok_embedding(src) * self.scale) + self.pos_embedding(pos))
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:1 and cuda:0!

As the traceback shows, the error occurs at line 15 of the train() function, which is defined as:

def train(model, iterator, optimizer, criterion, clip):
    model.train() # training mode
    epoch_loss = 0

    for i, batch in enumerate(tqdm(iterator, ascii=True, ncols=100)):
        src = batch.src
        trg = batch.trg

        optimizer.zero_grad()

        output, _ = model(src, trg[:,:-1])

        output_dim = output.shape[-1]

        output = output.contiguous().view(-1, output_dim)
        trg = trg[:,1:].contiguous().view(-1)

        loss = criterion(output, trg)
        loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)

        optimizer.step()

        epoch_loss += loss.item()

    return epoch_loss / len(iterator)

To solve this error, I tried wrapping the encoder and decoder in nn.DataParallel individually, but the same error is returned. I also tried wrapping the model inside the train() function.
I also tracked the variables to see which CUDA device each tensor is assigned to, but they only report "cuda", not "cuda:0" or "cuda:1". I've tried so many different things and still can't find the cause of this error. If anyone knows, please help me.
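
For reference, this is roughly how I checked (a rough sketch of the debug prints I added; device here is the global torch.device('cuda') I pass into the models):

print(device)       # the global device variable: just reports 'cuda'
print(src.device)   # device of the source batch
print(trg.device)   # device of the target batch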

Please let me know if you have any additional code I need to show you.

Can you show your forward pass?

The forward pass is where you’ll want to make sure to move any encoder layer outputs to the device of the decoder layer, before continuing.
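
For example, something along these lines inside forward() (just a sketch, assuming trg and the masks already live on the decoder's device):

enc_src = self.encoder(src, src_mask)
enc_src = enc_src.to(next(self.decoder.parameters()).device)  # move encoder output to the decoder's device
output, attention = self.decoder(trg, enc_src, trg_mask, src_mask)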

Hi, first of all, thank you so much for replying. I'm desperately trying to solve this problem.

class Transformer(nn.Module):
    def __init__(self, encoder, decoder, src_pad_idx, trg_pad_idx, device):
        super().__init__()
        print("Transformer")

        self.encoder = encoder
        self.decoder = decoder
        self.src_pad_idx = src_pad_idx
        self.trg_pad_idx = trg_pad_idx
        self.device = device

    def make_src_mask(self, src):
        print("Transformer.make_src_mask")

        src_mask = (src != self.src_pad_idx).unsqueeze(1).unsqueeze(2)

        return src_mask

    def make_trg_mask(self, trg):
        print("Transformer.make_trg_mask")

        trg_pad_mask = (trg != self.trg_pad_idx).unsqueeze(1).unsqueeze(2)

        trg_len = trg.shape[1]

        trg_sub_mask = torch.tril(torch.ones((trg_len, trg_len), device = self.device)).bool()

        trg_mask = trg_pad_mask & trg_sub_mask

        return trg_mask

    def forward(self, src, trg):
        print("Transformer.forward")

        src_mask = self.make_src_mask(src)
        trg_mask = self.make_trg_mask(trg)

        enc_src = self.encoder(src, src_mask)

        output, attention = self.decoder(trg, enc_src, trg_mask, src_mask)

        return output, attention

The forward() method declared at the end of the Transformer(nn.Module) class is the part that passes from the encoder to the decoder.

I’m not going to try to figure out what the exact issue is you’re having with DataParallel. Instead, I’ll give you something you can try that might help you accomplish your objective:

class Transformer(nn.Module):
    def __init__(self, encoder, decoder, src_pad_idx, trg_pad_idx, device1, device2):
        super().__init__()
        print("Transformer")

        self.encoder = encoder.to(device1)
        self.decoder = decoder.to(device2)
        self.src_pad_idx = src_pad_idx   # the pad indices are plain ints, no device needed
        self.trg_pad_idx = trg_pad_idx
        self.device1 = device1
        self.device2 = device2

    def make_src_mask(self, src):
        print("Transformer.make_src_mask")

        src_mask = (src != self.src_pad_idx).unsqueeze(1).unsqueeze(2)

        return src_mask

    def make_trg_mask(self, trg):
        print("Transformer.make_trg_mask")

        trg_pad_mask = (trg != self.trg_pad_idx).unsqueeze(1).unsqueeze(2) # on device2, since trg is moved there in forward()

        trg_len = trg.shape[1]

        trg_sub_mask = torch.tril(torch.ones((trg_len, trg_len), device = self.device2)).bool()

        trg_mask = trg_pad_mask & trg_sub_mask

        return trg_mask

    def forward(self, src, trg):
        print("Transformer.forward")

        src = src.to(self.device1)  # inputs live on the encoder's device
        trg = trg.to(self.device2)  # targets live on the decoder's device

        src_mask = self.make_src_mask(src)
        trg_mask = self.make_trg_mask(trg)

        enc_src = self.encoder(src, src_mask)

        # hand the encoder output and the source mask over to the decoder's device
        output, attention = self.decoder(trg, enc_src.to(self.device2), trg_mask, src_mask.to(self.device2))

        return output, attention
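
Roughly how you'd wire it up (a sketch, not tested; it assumes your Encoder and Decoder take the device they should create their internal tensors on, as in your original snippet, and that the loss targets are moved to device2 to match the decoder output):

device1 = torch.device('cuda:0')
device2 = torch.device('cuda:1')

enc = Encoder(INPUT_DIM, HIDDEN_DIM, ENC_LAYERS, ENC_HEADS, ENC_PF_DIM, ENC_DROPOUT, device1)
dec = Decoder(OUTPUT_DIM, HIDDEN_DIM, DEC_LAYERS, DEC_HEADS, DEC_PF_DIM, DEC_DROPOUT, device2)
model = Transformer(enc, dec, SRC_PAD_IDX, TRG_PAD_IDX, device1, device2)  # note: no nn.DataParallel here

# in train(), the decoder output lives on device2, so the flattened targets must too:
# loss = criterion(output, trg[:,1:].contiguous().view(-1).to(device2))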

I'm not sure whether all of your tensors end up on the device you intend in the above script. You can print out a tensor's device if you aren't sure:

print(src.get_device())

Just make sure they are all assigned properly.
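
If you'd rather check whole modules than individual tensors, a loop over the parameters also works (sketch):

for name, p in model.named_parameters():
    print(name, p.device)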