Input and model on different GPUs: error with a complex model when using nn.DataParallel

Error message:
Expected tensor for 'out' to have the same device as tensor for argument #2 'mat1'; but device 1 does not equal 0 (while checking arguments for addmm)

I understand this error has been discussed quite a lot, and after reading several posts I have a basic idea of why it occurs in my code: mainly, I am using a very complex model.

My code works fine in single-GPU mode. After adding torch.nn.DataParallel and running on a 4-GPU node, the error occurred. Could someone kindly have a look at my model and point out where to modify it, please?

CUDA Setting:

os.environ["CUDA_VISIBLE_DEVICES"]= '0,1,2,3'
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

Input variable CUDA setting:

PE_batch = get_pe(seq_lens, seq_len).float().to(device)
seq_embedding_batch = torch.Tensor(seq_embeddings.float()).to(device)
state_pad = torch.zeros([matrix_reps_batch.shape[0],seq_len, seq_len]).to(device)

Model instantiation and application:

contact_net = ContactAttention_simple_fix_PE(d=d, L=seq_len, device=device).to(device)
contact_net = torch.nn.DataParallel(contact_net)
output = contact_net(PE_batch,seq_embedding_batch, state_pad)

Model details (the problem should be here; I am using a subclass of another nn.Module subclass, so there are two model classes. While debugging, I added .to(self.device) to every operation in forward() in case I missed any of the layers):

class ContactAttention_simple(nn.Module):
    def __init__(self, d,L):
        super(ContactAttention_simple, self).__init__()
        self.d = d
        self.L = L
        self.conv1d1= nn.Conv1d(in_channels=4, out_channels=d,
            kernel_size=9, padding=8, dilation=2)
        self.bn1 = nn.BatchNorm1d(d)

        self.conv_test_1 = nn.Conv2d(in_channels=6*d, out_channels=d, kernel_size=1)
        self.bn_conv_1 = nn.BatchNorm2d(d)
        self.conv_test_2 = nn.Conv2d(in_channels=d, out_channels=d, kernel_size=1)
        self.bn_conv_2 = nn.BatchNorm2d(d)
        self.conv_test_3 = nn.Conv2d(in_channels=d, out_channels=1, kernel_size=1)

        self.position_embedding_1d = nn.Parameter(
            torch.randn(1, d, 600)
        )
        self.encoder_layer = nn.TransformerEncoderLayer(2*d, 2)
        self.transformer_encoder = nn.TransformerEncoder(self.encoder_layer, 3)

    def forward(self, prior, seq, state):

        position_embeds = self.position_embedding_1d.repeat(seq.shape[0],1,1)
        seq = seq.permute(0, 2, 1) # 4*L
        seq = F.relu(self.bn1(self.conv1d1(seq))) #d*L just for increase the capacity

        seq = torch.cat([seq, position_embeds], 1) # 2d*L
        seq = self.transformer_encoder(seq.permute(-1, 0, 1))
        seq = seq.permute(1, 2, 0)
        seq_mat = self.matrix_rep(seq) # 4d*L*L

        p_mat = self.matrix_rep(position_embeds) # 2d*L*L

        infor = torch.cat([seq_mat, p_mat], 1) # 6d*L*L

        contact = F.relu(self.bn_conv_1(self.conv_test_1(infor)))
        contact = F.relu(self.bn_conv_2(self.conv_test_2(contact)))
        contact = self.conv_test_3(contact)

        contact = contact.view(-1, self.L, self.L)
        contact = (contact+torch.transpose(contact, -1, -2))/2

        return contact.view(-1, self.L, self.L)

    def matrix_rep(self, x):
        x = x.permute(0, 2, 1) # L*d
        L = x.shape[1]
        x2 = x
        x = x.unsqueeze(1)
        x2 = x2.unsqueeze(2)
        x = x.repeat(1, L,1,1)
        x2 = x2.repeat(1, 1, L,1)
        mat = torch.cat([x,x2],-1) # L*L*2d
        mat_tril = torch.tril(mat.permute(0, -1, 1, 2)) # 2d*L*L
        mat_diag = mat_tril - torch.tril(mat.permute(0, -1, 1, 2), diagonal=-1)
        mat = mat_tril + torch.transpose(mat_tril, -2, -1) - mat_diag
        return mat


class ContactAttention_simple_fix_PE(ContactAttention_simple):
    def __init__(self, d, L, device):
        super(ContactAttention_simple_fix_PE, self).__init__(d, L)
        self.device=device
        self.PE_net = nn.Sequential(
            nn.Linear(111,5*d),
            nn.ReLU(),
            nn.Linear(5*d,5*d),
            nn.ReLU(),
            nn.Linear(5*d,d))

    def forward(self, pe, seq, state):

        position_embeds = self.PE_net(pe.view(-1, 111).to(self.device)).view(-1, self.L, self.d).to(self.device) # N*L*111 -> N*L*d
        position_embeds = position_embeds.permute(0, 2, 1).to(self.device) # N*d*L
        seq = seq.permute(0, 2, 1).to(self.device) # 4*L
        seq = F.relu(self.bn1(self.conv1d1(seq))).to(self.device) #d*L just for increase the capacity

        seq = torch.cat([seq, position_embeds], 1).to(self.device) # 2d*L
        seq = self.transformer_encoder(seq.permute(-1, 0, 1).to(self.device)).to(self.device)
        seq = seq.permute(1, 2, 0).to(self.device)

        seq_mat = self.matrix_rep(seq).to(self.device) # 4d*L*L

        p_mat = self.matrix_rep(position_embeds).to(self.device) # 2d*L*L

        infor = torch.cat([seq_mat, p_mat], 1).to(self.device) # 6d*L*L

        contact = F.relu(self.bn_conv_1(self.conv_test_1(infor))).to(self.device)
        contact = F.relu(self.bn_conv_2(self.conv_test_2(contact))).to(self.device)
        contact = self.conv_test_3(contact).to(self.device)

        contact = contact.view(-1, self.L, self.L).to(self.device)
        contact = ((contact.to(self.device)+torch.transpose(contact, -1, -2).to(self.device))/2).to(self.device)

        return contact.view(-1, self.L, self.L).to(self.device)
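Both forward() methods concatenate the Conv1d output with position_embeds along the channel dimension, which only works if conv1d1 preserves the sequence length L. The output-length formula from the torch.nn.Conv1d documentation confirms that kernel_size=9, padding=8, dilation=2 does so (2 * padding equals dilation * (kernel_size - 1)); conv1d_out_len is a hypothetical helper written only for this check.

```python
def conv1d_out_len(l_in, kernel_size, padding=0, dilation=1, stride=1):
    """Hypothetical helper: output-length formula from the torch.nn.Conv1d docs."""
    return (l_in + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


# conv1d1 = nn.Conv1d(in_channels=4, out_channels=d, kernel_size=9, padding=8, dilation=2)
for L in (111, 600):
    print(L, "->", conv1d_out_len(L, kernel_size=9, padding=8, dilation=2))
# 111 -> 111
# 600 -> 600
```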

Hey @irleader, I think the problem is in the lines (there are several) that use self.device in the following way. DataParallel replicates your model to all provided/visible devices, so in the above case there will be four threads, each holding one replica of self.PE_net on a different device. However, DataParallel is not smart enough to modify self.device for you, so self.device in all threads will point to the same device: the one you passed to the ContactAttention_simple_fix_PE() constructor. Hence the device mismatch.

self.PE_net(pe.view(-1, 111).to(self.device))

If pe is a tensor, DataParallel should have already scattered it to the correct device. Is there any reason for calling .to(self.device) again?
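A minimal sketch of the device-agnostic pattern this implies, using a hypothetical toy module rather than the original model: nothing stores a device in __init__, new tensors take their device from the already-scattered input, and a constant lives in a registered buffer so that it moves with the parameters when the module is moved or replicated.

```python
import torch
import torch.nn as nn


class DeviceAgnostic(nn.Module):
    """Hypothetical toy module: it never stores a device in __init__."""

    def __init__(self, d):
        super().__init__()
        self.linear = nn.Linear(d, d)
        # A constant that must live next to the weights: registering it as a
        # buffer means .to(), .cuda() and DataParallel replication move it too.
        self.register_buffer("scale", torch.tensor(0.5))

    def forward(self, x):
        # x has already been scattered to this replica's device, so a new
        # tensor takes its device from x instead of from a stored attribute.
        offset = torch.zeros(x.shape[0], 1, device=x.device)
        return self.linear(x + offset) * self.scale


model = DeviceAgnostic(4)
x = torch.randn(3, 4)
out = model(x)
print(out.device == x.device)  # True
```

Wrapped as nn.DataParallel(DeviceAgnostic(4).cuda()), each replica should then see x, offset and scale on its own GPU.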

cc @VitalyFedyunin

Hi, Shen Li,

Thanks a lot for pointing this out. Those .to(self.device) calls were not there originally; I added them while trying to get rid of the error message, because several posts said that any layer of the model not defined in __init__() but used in forward() needs a .to(device).

Even if I remove all the .to(self.device) calls from forward(), the error is still there.

I was thinking the error might be caused by def matrix_rep(self, x), which is defined outside of __init__() but used in forward(), but I have no idea how to modify it.
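One way to check this worry is to run the same operations on their own. The sketch below copies matrix_rep into a free function (a hypothetical refactor for testing only) and confirms that it only applies tensor ops to its input, so its output stays on whatever device x is on. It also shows the N x 2d x L x L output shape and that entry (i, j) holds the embeddings of positions min(i, j) and max(i, j), concatenated.

```python
import torch


def matrix_rep(x):
    # Same operations as ContactAttention_simple.matrix_rep, as a free function
    x = x.permute(0, 2, 1)                    # N x L x d
    L = x.shape[1]
    x2 = x
    x = x.unsqueeze(1).repeat(1, L, 1, 1)     # x[n, i, j] = embedding of j
    x2 = x2.unsqueeze(2).repeat(1, 1, L, 1)   # x2[n, i, j] = embedding of i
    mat = torch.cat([x, x2], -1)              # N x L x L x 2d
    mat_tril = torch.tril(mat.permute(0, -1, 1, 2))  # N x 2d x L x L
    mat_diag = mat_tril - torch.tril(mat.permute(0, -1, 1, 2), diagonal=-1)
    # Mirror the lower triangle into the upper one, counting the diagonal once
    return mat_tril + torch.transpose(mat_tril, -2, -1) - mat_diag


x = torch.randn(2, 3, 5)                       # N=2, d=3, L=5
out = matrix_rep(x)
print(out.shape)                               # torch.Size([2, 6, 5, 5])
print(out.device == x.device)                  # True
print(torch.equal(out, out.transpose(-1, -2)))  # True
```

Since no new tensor is created inside the method, it needs no .to(device) call; the same holds for any helper method that only transforms its inputs.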

@ptrblck_de Hi ptrblck, could you have a look at my code, please? I see that you have answered lots of similar questions. Thanks.

Which line reported the above error?

Hi,

The error message was reported on this line:

output = contact_net(PE_batch,seq_embedding_batch, state_pad)

Thanks.

As described by @mrshenli, the to(device) calls inside forward would cause this error, and your model works without them:

import torch
import torch.nn as nn
import torch.nn.functional as F

class ContactAttention_simple(nn.Module):
    def __init__(self, d,L):
        super(ContactAttention_simple, self).__init__()
        self.d = d
        self.L = L
        self.conv1d1= nn.Conv1d(in_channels=4, out_channels=d,
            kernel_size=9, padding=8, dilation=2)
        self.bn1 = nn.BatchNorm1d(d)

        self.conv_test_1 = nn.Conv2d(in_channels=6*d, out_channels=d, kernel_size=1)
        self.bn_conv_1 = nn.BatchNorm2d(d)
        self.conv_test_2 = nn.Conv2d(in_channels=d, out_channels=d, kernel_size=1)
        self.bn_conv_2 = nn.BatchNorm2d(d)
        self.conv_test_3 = nn.Conv2d(in_channels=d, out_channels=1, kernel_size=1)

        self.position_embedding_1d = nn.Parameter(
            torch.randn(1, d, 600)
        )
        self.encoder_layer = nn.TransformerEncoderLayer(2*d, 2)
        self.transformer_encoder = nn.TransformerEncoder(self.encoder_layer, 3)

    def forward(self, prior, seq, state):

        position_embeds = self.position_embedding_1d.repeat(seq.shape[0],1,1)
        seq = seq.permute(0, 2, 1) # 4*L
        seq = F.relu(self.bn1(self.conv1d1(seq))) #d*L just for increase the capacity

        seq = torch.cat([seq, position_embeds], 1) # 2d*L
        seq = self.transformer_encoder(seq.permute(-1, 0, 1))
        seq = seq.permute(1, 2, 0)
        seq_mat = self.matrix_rep(seq) # 4d*L*L

        p_mat = self.matrix_rep(position_embeds) # 2d*L*L

        infor = torch.cat([seq_mat, p_mat], 1) # 6d*L*L

        contact = F.relu(self.bn_conv_1(self.conv_test_1(infor)))
        contact = F.relu(self.bn_conv_2(self.conv_test_2(contact)))
        contact = self.conv_test_3(contact)

        contact = contact.view(-1, self.L, self.L)
        contact = (contact+torch.transpose(contact, -1, -2))/2

        return contact.view(-1, self.L, self.L)

    def matrix_rep(self, x):
        x = x.permute(0, 2, 1) # L*d
        L = x.shape[1]
        x2 = x
        x = x.unsqueeze(1)
        x2 = x2.unsqueeze(2)
        x = x.repeat(1, L,1,1)
        x2 = x2.repeat(1, 1, L,1)
        mat = torch.cat([x,x2],-1) # L*L*2d
        mat_tril = torch.tril(mat.permute(0, -1, 1, 2)) # 2d*L*L
        mat_diag = mat_tril - torch.tril(mat.permute(0, -1, 1, 2), diagonal=-1)
        mat = mat_tril + torch.transpose(mat_tril, -2, -1) - mat_diag
        return mat


class ContactAttention_simple_fix_PE(ContactAttention_simple):
    def __init__(self, d, L):
        super(ContactAttention_simple_fix_PE, self).__init__(d, L)
        self.PE_net = nn.Sequential(
            nn.Linear(111,5*d),
            nn.ReLU(),
            nn.Linear(5*d,5*d),
            nn.ReLU(),
            nn.Linear(5*d,d))

    def forward(self, pe, seq):
        print(pe.shape, pe.device)

        position_embeds = self.PE_net(pe.view(-1, 111)).view(-1, self.L, self.d) # N*L*111 -> N*L*d
        position_embeds = position_embeds.permute(0, 2, 1) # N*d*L
        seq = seq.permute(0, 2, 1) # 4*L
        seq = F.relu(self.bn1(self.conv1d1(seq))) #d*L just for increase the capacity

        seq = torch.cat([seq, position_embeds], 1) # 2d*L
        seq = self.transformer_encoder(seq.permute(-1, 0, 1))
        seq = seq.permute(1, 2, 0)

        seq_mat = self.matrix_rep(seq) # 4d*L*L

        p_mat = self.matrix_rep(position_embeds) # 2d*L*L

        infor = torch.cat([seq_mat, p_mat], 1) # 6d*L*L

        contact = F.relu(self.bn_conv_1(self.conv_test_1(infor)))
        contact = F.relu(self.bn_conv_2(self.conv_test_2(contact)))
        contact = self.conv_test_3(contact)

        contact = contact.view(-1, self.L, self.L)
        contact = ((contact+torch.transpose(contact, -1, -2))/2)

        return contact.view(-1, self.L, self.L)


model = ContactAttention_simple_fix_PE(1, 111).cuda()
model = nn.DataParallel(model)
x = torch.randn(8, 111, 111).cuda()
seq = torch.randn(8, 111, 4).cuda()
out = model(x, seq)
print(out.shape)

Output without nn.DataParallel:

torch.Size([8, 111, 111]) cuda:0
torch.Size([8, 111, 111])

Output with nn.DataParallel:

torch.Size([1, 111, 111]) cuda:0
torch.Size([1, 111, 111]) cuda:1
torch.Size([1, 111, 111]) cuda:2
torch.Size([1, 111, 111]) cuda:3
torch.Size([1, 111, 111]) cuda:4
torch.Size([1, 111, 111]) cuda:5
torch.Size([1, 111, 111]) cuda:6
torch.Size([1, 111, 111]) cuda:7
torch.Size([8, 111, 111])

Note that I have used random tensors with shapes that seem to work for this model, as I didn't see any information regarding the shapes.

@ptrblck Thanks a lot for this. I tried your code: it works on the first step, but then the error occurs again. Here is the full error message:

torch.Size([5, 600, 111]) cuda:0
torch.Size([5, 600, 111]) cuda:1
torch.Size([5, 600, 111]) cuda:2
torch.Size([5, 600, 111]) cuda:3
Stage 1, epoch: 0,step: 0, loss: 0.5405522584915161
torch.Size([2, 600, 111]) cuda:0
torch.Size([2, 600, 111]) cuda:1
torch.Size([1, 600, 111]) cuda:2
Traceback (most recent call last):
  File "e2e_learning_stage1.py", line 349, in
    output = contact_net(PE_batch,seq_embedding_batch, state_pad)
  File "/g/data/ik06/jiajia/python3packages/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/g/data/ik06/jiajia/python3packages/lib/python3.8/site-packages/torch/nn/parallel/data_parallel.py", line 161, in forward
    outputs = self.parallel_apply(replicas, inputs, kwargs)
  File "/g/data/ik06/jiajia/python3packages/lib/python3.8/site-packages/torch/nn/parallel/data_parallel.py", line 171, in parallel_apply
    return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
  File "/g/data/ik06/jiajia/python3packages/lib/python3.8/site-packages/torch/nn/parallel/parallel_apply.py", line 86, in parallel_apply
    output.reraise()
  File "/g/data/ik06/jiajia/python3packages/lib/python3.8/site-packages/torch/_utils.py", line 428, in reraise
    raise self.exc_type(msg)
RuntimeError: Caught RuntimeError in replica 0 on device 0.
Original Traceback (most recent call last):
  File "/g/data/ik06/jiajia/python3packages/lib/python3.8/site-packages/torch/nn/parallel/parallel_apply.py", line 61, in _worker
    output = module(*input, **kwargs)
  File "/g/data/ik06/jiajia/python3packages/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/g/data/ik06/jiajia/python3packages/lib/python3.8/site-packages/torch/nn/parallel/data_parallel.py", line 161, in forward
    outputs = self.parallel_apply(replicas, inputs, kwargs)
  File "/g/data/ik06/jiajia/python3packages/lib/python3.8/site-packages/torch/nn/parallel/data_parallel.py", line 171, in parallel_apply
    return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
  File "/g/data/ik06/jiajia/python3packages/lib/python3.8/site-packages/torch/nn/parallel/parallel_apply.py", line 86, in parallel_apply
    output.reraise()
  File "/g/data/ik06/jiajia/python3packages/lib/python3.8/site-packages/torch/_utils.py", line 428, in reraise
    raise self.exc_type(msg)
RuntimeError: Caught RuntimeError in replica 1 on device 1.
Original Traceback (most recent call last):
  File "/g/data/ik06/jiajia/python3packages/lib/python3.8/site-packages/torch/nn/parallel/parallel_apply.py", line 61, in _worker
    output = module(*input, **kwargs)
  File "/g/data/ik06/jiajia/python3packages/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/g/data/ik06/jiajia/e2efold_master/e2efold/models.py", line 252, in forward
    position_embeds = self.PE_net(pe.view(-1, 111)).view(-1, self.L, self.d) # N*L*111 -> N*L*d
  File "/g/data/ik06/jiajia/python3packages/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/g/data/ik06/jiajia/python3packages/lib/python3.8/site-packages/torch/nn/modules/container.py", line 117, in forward
    input = module(input)
  File "/g/data/ik06/jiajia/python3packages/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/g/data/ik06/jiajia/python3packages/lib/python3.8/site-packages/torch/nn/modules/linear.py", line 93, in forward
    return F.linear(input, self.weight, self.bias)
  File "/g/data/ik06/jiajia/python3packages/lib/python3.8/site-packages/torch/nn/functional.py", line 1690, in linear
    ret = torch.addmm(bias, input, weight.t())
RuntimeError: Expected tensor for 'out' to have the same device as tensor for argument #2 'mat1'; but device 0 does not equal 1 (while checking arguments for addmm)

I am using batch size 20 and 4 GPUs. I will also attach my training code here in case it helps:

for epoch in range(epoches_first):
    for contacts, seq_embeddings, matrix_reps, seq_lens in train_generator:
        contact_net.train()
        contacts_batch = torch.Tensor(contacts.float()).cuda()
        seq_embedding_batch = torch.Tensor(seq_embeddings.float()).cuda()
        matrix_reps_batch = torch.unsqueeze(torch.Tensor(matrix_reps.float()).cuda(), -1)
        state_pad = torch.zeros([matrix_reps_batch.shape[0],seq_len, seq_len]).cuda()
        PE_batch = get_pe(seq_lens, seq_len).float().cuda()
        contact_masks = torch.Tensor(contact_map_masks(seq_lens, seq_len)).cuda()

        contact_net = torch.nn.DataParallel(contact_net)
        output = contact_net(PE_batch,seq_embedding_batch, state_pad)
        # Compute loss
        loss_u = criterion_bce_weighted(output*contact_masks, contacts_batch)
        # print(steps_done)
        if steps_done % OUT_STEP ==0:
            print('Stage 1, epoch: {},step: {}, loss: {}'.format(
                epoch, steps_done, loss_u))

        # Optimize the model
        u_optimizer.zero_grad()
        loss_u.backward()
        u_optimizer.step()
        steps_done=steps_done+1

Just to add the seed_torch() function I used, to make sure it's not causing trouble:

def seed_torch(seed=0):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed) # if you are using multi-GPU.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

Interestingly, it seems you get the error if the batches are imbalanced, i.e. if the data cannot be split equally across the devices.
I just tested my code with a batch size of 5 (which is your last case before the error is raised) and it still works fine. Could you do the same and check, if you are seeing an error?
Also, are you using the latest PyTorch version?
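The per-GPU shapes in the logs follow from how DataParallel splits a batch along dim 0. Assuming the default scatter, which chunks with torch.chunk semantics (each chunk gets ceil(batch / n) rows, the last one may be smaller, and fewer than n chunks can come back), the sizes can be reproduced in plain Python; chunk_sizes is a hypothetical helper, not a PyTorch function.

```python
import math


def chunk_sizes(batch, n_devices):
    """Hypothetical helper mimicking torch.chunk(batch_tensor, n_devices, dim=0)."""
    size = math.ceil(batch / n_devices)
    sizes = []
    remaining = batch
    while remaining > 0:
        sizes.append(min(size, remaining))
        remaining -= size
    return sizes


print(chunk_sizes(20, 4))  # [5, 5, 5, 5]  four replicas, 5 samples each
print(chunk_sizes(5, 4))   # [2, 2, 1]     only three replicas receive data
```

With 5 samples on 4 GPUs only three replicas run, which matches the cuda:0 to cuda:2 lines printed for the first step with batch size 5.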

@ptrblck Thanks again for your time. I have tested with batch size of 5 with 4 GPUs. The same error occurs:

torch.Size([2, 600, 111]) cuda:0
torch.Size([2, 600, 111]) cuda:1
torch.Size([1, 600, 111]) cuda:2
Stage 1, epoch: 0,step: 0, loss: 0.5292213559150696
torch.Size([1, 600, 111]) cuda:0
torch.Size([1, 600, 111]) cuda:1
Traceback (most recent call last):
  File "e2e_learning_stage1.py", line 343, in
    output = contact_net(PE_batch,seq_embedding_batch, state_pad)
  File "/home/248/jx3129/anaconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/248/jx3129/anaconda3/lib/python3.8/site-packages/torch/nn/parallel/data_parallel.py", line 161, in forward
    outputs = self.parallel_apply(replicas, inputs, kwargs)
  File "/home/248/jx3129/anaconda3/lib/python3.8/site-packages/torch/nn/parallel/data_parallel.py", line 171, in parallel_apply
    return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
  File "/home/248/jx3129/anaconda3/lib/python3.8/site-packages/torch/nn/parallel/parallel_apply.py", line 86, in parallel_apply
    output.reraise()
  File "/home/248/jx3129/anaconda3/lib/python3.8/site-packages/torch/_utils.py", line 428, in reraise
    raise self.exc_type(msg)
RuntimeError: Caught RuntimeError in replica 0 on device 0.
Original Traceback (most recent call last):
  File "/home/248/jx3129/anaconda3/lib/python3.8/site-packages/torch/nn/parallel/parallel_apply.py", line 61, in _worker
    output = module(*input, **kwargs)
  File "/home/248/jx3129/anaconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/248/jx3129/anaconda3/lib/python3.8/site-packages/torch/nn/parallel/data_parallel.py", line 161, in forward
    outputs = self.parallel_apply(replicas, inputs, kwargs)
  File "/home/248/jx3129/anaconda3/lib/python3.8/site-packages/torch/nn/parallel/data_parallel.py", line 171, in parallel_apply
    return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
  File "/home/248/jx3129/anaconda3/lib/python3.8/site-packages/torch/nn/parallel/parallel_apply.py", line 86, in parallel_apply
    output.reraise()
  File "/home/248/jx3129/anaconda3/lib/python3.8/site-packages/torch/_utils.py", line 428, in reraise
    raise self.exc_type(msg)
RuntimeError: Caught RuntimeError in replica 1 on device 1.
Original Traceback (most recent call last):
  File "/home/248/jx3129/anaconda3/lib/python3.8/site-packages/torch/nn/parallel/parallel_apply.py", line 61, in _worker
    output = module(*input, **kwargs)
  File "/home/248/jx3129/anaconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/g/data/ik06/jiajia/e2efold_master/e2efold/models.py", line 252, in forward
    position_embeds = self.PE_net(pe.view(-1, 111)).view(-1, self.L, self.d) # N*L*111 -> N*L*d
  File "/home/248/jx3129/anaconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/248/jx3129/anaconda3/lib/python3.8/site-packages/torch/nn/modules/container.py", line 117, in forward
    input = module(input)
  File "/home/248/jx3129/anaconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/248/jx3129/anaconda3/lib/python3.8/site-packages/torch/nn/modules/linear.py", line 93, in forward
    return F.linear(input, self.weight, self.bias)
  File "/home/248/jx3129/anaconda3/lib/python3.8/site-packages/torch/nn/functional.py", line 1690, in linear
    ret = torch.addmm(bias, input, weight.t())
RuntimeError: Expected tensor for 'out' to have the same device as tensor for argument #2 'mat1'; but device 0 does not equal 1 (while checking arguments for addmm)

I am using pytorch 1.7.0.

The second iteration is now using only a batch size of 2 (or 3 in case the script is crashing before cuda:2 is executed), so in your script something is reducing the batch size. Could you post an executable code snippet (using my template) to reproduce this issue?

@ptrblck I am so sorry, but the code I am running is a big project consisting of several scripts, and it's hard to merge them into one executable snippet. Could you give me your email, or should I upload it to Google Drive and send you the link? Is that OK?

@ptrblck Thanks in advance for your time. I managed to merge all scripts into one and reproduced the error with it. The raw data I am using, test.pickle, is uploaded to Google Drive; put the script and the data in the same directory: https://drive.google.com/file/d/1DxVtd9ejMns644EoF31Mf8JjQxK94JsH/view?usp=sharing

I attach my code here:

All functions and models needed:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils import data
import os
import torch.optim as optim
import math
import numpy as np
import _pickle as cPickle
import collections
from random import shuffle

os.environ["CUDA_VISIBLE_DEVICES"]= '0,1,2,3'
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## model
class ContactAttention_simple(nn.Module):
    """docstring for ContactAttention_simple"""
    def __init__(self, d,L):
        super(ContactAttention_simple, self).__init__()
        self.d = d
        self.L = L
        self.conv1d1= nn.Conv1d(in_channels=4, out_channels=d,
            kernel_size=9, padding=8, dilation=2)
        self.bn1 = nn.BatchNorm1d(d)

        self.conv_test_1 = nn.Conv2d(in_channels=6*d, out_channels=d, kernel_size=1)
        self.bn_conv_1 = nn.BatchNorm2d(d)
        self.conv_test_2 = nn.Conv2d(in_channels=d, out_channels=d, kernel_size=1)
        self.bn_conv_2 = nn.BatchNorm2d(d)
        self.conv_test_3 = nn.Conv2d(in_channels=d, out_channels=1, kernel_size=1)

        self.position_embedding_1d = nn.Parameter(
            torch.randn(1, d, 600)
        )

        # transformer encoder for the input sequences
        self.encoder_layer = nn.TransformerEncoderLayer(2*d, 2)
        self.transformer_encoder = nn.TransformerEncoder(self.encoder_layer, 3)

    def forward(self, prior, seq, state):
        """
        prior: L*L*1
        seq: L*4
        state: L*L
        """

        position_embeds = self.position_embedding_1d.repeat(seq.shape[0],1,1)
        seq = seq.permute(0, 2, 1) # 4*L
        seq = F.relu(self.bn1(self.conv1d1(seq))) #d*L just for increase the capacity

        seq = torch.cat([seq, position_embeds], 1) # 2d*L
        seq = self.transformer_encoder(seq.permute(-1, 0, 1))
        seq = seq.permute(1, 2, 0)

        # what about apply attention on the the 2d map?
        seq_mat = self.matrix_rep(seq) # 4d*L*L

        p_mat = self.matrix_rep(position_embeds) # 2d*L*L

        infor = torch.cat([seq_mat, p_mat], 1) # 6d*L*L

        contact = F.relu(self.bn_conv_1(self.conv_test_1(infor)))
        contact = F.relu(self.bn_conv_2(self.conv_test_2(contact)))
        contact = self.conv_test_3(contact)

        contact = contact.view(-1, self.L, self.L)
        contact = (contact+torch.transpose(contact, -1, -2))/2

        return contact.view(-1, self.L, self.L)

    def matrix_rep(self, x):
        '''
        for each position i,j of the matrix, we concatenate the embedding of i and j
        '''
        x = x.permute(0, 2, 1) # L*d
        L = x.shape[1]
        x2 = x
        x = x.unsqueeze(1)
        x2 = x2.unsqueeze(2)
        x = x.repeat(1, L,1,1)
        x2 = x2.repeat(1, 1, L,1)
        mat = torch.cat([x,x2],-1) # L*L*2d

        # make it symmetric
        # mat_tril = torch.cat(
        #     [torch.tril(mat[:,:, i]) for i in range(mat.shape[-1])], -1)
        mat_tril = torch.tril(mat.permute(0, -1, 1, 2)) # 2d*L*L
        mat_diag = mat_tril - torch.tril(mat.permute(0, -1, 1, 2), diagonal=-1)
        mat = mat_tril + torch.transpose(mat_tril, -2, -1) - mat_diag
        return mat


class ContactAttention_simple_fix_PE(ContactAttention_simple):
    """docstring for ContactAttention_simple_fix_PE"""
    def __init__(self, d, L, device):
        super(ContactAttention_simple_fix_PE, self).__init__(d, L)
        self.PE_net = nn.Sequential(
            nn.Linear(111,5*d),
            nn.ReLU(),
            nn.Linear(5*d,5*d),
            nn.ReLU(),
            nn.Linear(5*d,d))

    def forward(self, pe, seq, state):
        """
        prior: L*L*1
        seq: L*4
        state: L*L
        """
        print(pe.shape, pe.device)
        position_embeds = self.PE_net(pe.view(-1, 111)).view(-1, self.L, self.d) # N*L*111 -> N*L*d
        position_embeds = position_embeds.permute(0, 2, 1) # N*d*L
        seq = seq.permute(0, 2, 1) # 4*L
        seq = F.relu(self.bn1(self.conv1d1(seq))) #d*L just for increase the capacity

        seq = torch.cat([seq, position_embeds], 1) # 2d*L
        seq = self.transformer_encoder(seq.permute(-1, 0, 1))
        seq = seq.permute(1, 2, 0)

        # what about apply attention on the the 2d map?
        seq_mat = self.matrix_rep(seq) # 4d*L*L

        p_mat = self.matrix_rep(position_embeds) # 2d*L*L

        infor = torch.cat([seq_mat, p_mat], 1) # 6d*L*L

        contact = F.relu(self.bn_conv_1(self.conv_test_1(infor)))
        contact = F.relu(self.bn_conv_2(self.conv_test_2(contact)))
        contact = self.conv_test_3(contact)

        contact = contact.view(-1, self.L, self.L)
        contact = (contact+torch.transpose(contact, -1, -2))/2

        return contact.view(-1, self.L, self.L)


char_dict = {
    0: 'A',
    1: 'U',
    2: 'C',
    3: 'G'
}

def encoding2seq(arr):
    seq = list()
    for arr_row in list(arr):
        if sum(arr_row)==0:
            seq.append('.')
        else:
            seq.append(char_dict[np.argmax(arr_row)])
    return ''.join(seq)

class RNASSDataGenerator(object):
    def __init__(self, data_dir, split, upsampling=False):
        self.data_dir = data_dir
        self.split = split
        self.upsampling = upsampling
        # Load vocab explicitly when needed
        self.load_data()
        # Reset batch pointer to zero
        self.batch_pointer = 0

    def load_data(self):
        data_dir = self.data_dir
        # Load the current split
        RNA_SS_data = collections.namedtuple('RNA_SS_data','seq ss_label length name pairs')
        with open(os.path.join(data_dir, '%s.pickle' % self.split), 'rb') as f:
            self.data = cPickle.load(f)
        #if self.upsampling:
        #    self.data = self.upsampling_data()
        self.data_x = np.array([instance[0] for instance in self.data])
        self.data_y = np.array([instance[1] for instance in self.data])
        self.pairs = np.array([instance[-1] for instance in self.data])
        self.seq_length = np.array([instance[2] for instance in self.data])
        self.len = len(self.data)

        self.seq = list(map(encoding2seq, self.data_x))
        self.seq_max_len = len(self.data_x[0])


    def pairs2map(self, pairs):
        seq_len = self.seq_max_len
        contact = np.zeros([seq_len, seq_len])
        for pair in pairs:
            contact[pair[0], pair[1]] = 1
        return contact

    def get_one_sample(self, index):
        # This will return a smaller size if not sufficient
        # The user must pad the batch in an external API
        # Or write a TF module with variable batch size
        data_y = self.data_y[index]
        data_seq = self.data_x[index]
        data_len = self.seq_length[index]
        data_pair = self.pairs[index]

        contact= self.pairs2map(data_pair)
        matrix_rep = np.zeros(contact.shape)
        return contact, data_seq, matrix_rep, data_len

class Dataset(data.Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return self.data.len

    def __getitem__(self, index):
        return self.data.get_one_sample(index)


#position embedding
def get_pe(seq_lens, max_len):
    #batch_size*1--> batch_size N
    num_seq = seq_lens.shape[0]
    #absolute position: from 1 to 600 : N*L*1
    pos_i_abs = torch.Tensor(np.arange(1,max_len+1)).view(1,
        -1, 1).expand(num_seq, -1, -1).double()
    #relatve position: from 1 to 600: N*L
    pos_i_rel = torch.Tensor(np.arange(1,max_len+1)).view(1, -1).expand(num_seq, -1)
    # N*L/N*1 --> N*L
    pos_i_rel = pos_i_rel.double()/seq_lens.view(-1, 1).double()
    pos_i_rel = pos_i_rel.unsqueeze(-1) #N*L*1
    pos = torch.cat([pos_i_abs, pos_i_rel], -1) #N*L*2

    PE_element_list = list()
    # 1/x, 1/x^2
    PE_element_list.append(pos) #N*L*2
    PE_element_list.append(1.0/pos_i_abs) #N*L*1
    PE_element_list.append(1.0/torch.pow(pos_i_abs, 2)) #N*L*1

    # sin(nx)
    for n in range(1, 50):
        PE_element_list.append(torch.sin(n*pos)) # 49(N*L*2)

    # poly
    for i in range(2, 5):
        PE_element_list.append(torch.pow(pos_i_rel, i)) #3(N*L*1)

    for i in range(3):
        gaussian_base = torch.exp(-torch.pow(pos,
            2))*math.sqrt(math.pow(2,i)/math.factorial(i))*torch.pow(pos, i)
        PE_element_list.append(gaussian_base) #3(N*L*2)

    PE = torch.cat(PE_element_list, -1) #N*L*111
    # zero padding
    for i in range(num_seq):
        PE[i, seq_lens[i]:, :] = 0
    return PE

def contact_map_masks(seq_lens, max_len):
    n_seq = len(seq_lens) #N
    masks = np.zeros([n_seq, max_len, max_len]) #N*L*L
    for i in range(n_seq):
        l = int(seq_lens[i].cpu().numpy())
        masks[i, :l, :l]=1
    return masks
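The 111 in nn.Linear(111, 5*d) and pe.view(-1, 111) has to match the last dimension returned by get_pe(). Tallying the pieces appended in get_pe(), with the loop ranges copied from the code, gives exactly that; this is only a bookkeeping check, not part of the original script.

```python
# Features per position produced by get_pe(), one entry per group of appends
pe_parts = {
    "pos = [abs, rel]": 2,
    "1 / abs": 1,
    "1 / abs**2": 1,
    "sin(n * pos), n in range(1, 50)": len(range(1, 50)) * 2,
    "rel**i, i in range(2, 5)": len(range(2, 5)),
    "gaussian basis on pos, i in range(3)": len(range(3)) * 2,
}
total = sum(pe_parts.values())
print(total)  # 111
```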

Data and Model Training:

train_data = RNASSDataGenerator('./','test')
seq_len = train_data.data_y.shape[-2]

params = {'batch_size': 8,
          'shuffle': True,
          'num_workers': 0,
          'drop_last': True}
train_set = Dataset(train_data)
train_generator = data.DataLoader(train_set, **params)



contact_net = ContactAttention_simple_fix_PE(d=10, L=seq_len, device=device).to(device)
u_optimizer = optim.Adam(contact_net.parameters())
pos_weight = torch.Tensor([300]).to(device)
criterion_bce_weighted = torch.nn.BCEWithLogitsLoss(
    pos_weight = pos_weight)


steps_done = 0
for epoch in range(50):

    for contacts, seq_embeddings, matrix_reps, seq_lens in train_generator:

        contact_net.train()
        contacts_batch = torch.Tensor(contacts.float()).to(device)
        seq_embedding_batch = torch.Tensor(seq_embeddings.float()).to(device)
        matrix_reps_batch = torch.unsqueeze(torch.Tensor(matrix_reps.float()).to(device), -1)
        # padding the states for supervised training with all 0s
        state_pad = torch.zeros([matrix_reps_batch.shape[0],seq_len, seq_len]).to(device)
        PE_batch = get_pe(seq_lens, seq_len).float().to(device)
        contact_masks = torch.Tensor(contact_map_masks(seq_lens, seq_len)).to(device)
        contact_net = torch.nn.DataParallel(contact_net)
        output = contact_net(PE_batch,seq_embedding_batch, state_pad)

        # Compute loss
        loss_u = criterion_bce_weighted(output*contact_masks, contacts_batch)

        # print(steps_done)
        if steps_done % 100 ==0:
            print('Stage 1, epoch: {},step: {}, loss: {}'.format(
                epoch, steps_done, loss_u))

        # Optimize the model
        u_optimizer.zero_grad()
        loss_u.backward()
        u_optimizer.step()
        steps_done=steps_done+1

Are you seeing the same issue using random data? If so, could you post the shapes so that I could reproduce it without downloading your dataset?

@ptrblck Unfortunately not: since I am using predefined functions (class RNASSDataGenerator) to preprocess the input data, I have to stick with that data format. I doubt the class RNASSDataGenerator is causing the trouble.

When I use the CPU, the batch shape is correct at every step.

torch.Size([8, 600, 111]) cpu
Stage 1, epoch: 0,step: 0, loss: 0.7845466136932373
torch.Size([8, 600, 111]) cpu
torch.Size([8, 600, 111]) cpu
torch.Size([8, 600, 111]) cpu
torch.Size([8, 600, 111]) cpu
torch.Size([8, 600, 111]) cpu
torch.Size([8, 600, 111]) cpu
torch.Size([8, 600, 111]) cpu
torch.Size([8, 600, 111]) cpu
torch.Size([8, 600, 111]) cpu
torch.Size([8, 600, 111]) cpu
torch.Size([8, 600, 111]) cpu

It seems the error is raised by re-wrapping the model in nn.DataParallel in each iteration.
Move contact_net = torch.nn.DataParallel(contact_net) before the epoch loop and it should work.
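The training loop rebinds contact_net to a new wrapper on every step, so from the second step on the model sits inside nested nn.DataParallel objects. A small sketch that makes the nesting visible, using nn.Linear as a stand-in for the real model and a hypothetical wrap_depth helper:

```python
import torch.nn as nn


def wrap_depth(m):
    """Hypothetical helper: count the nn.DataParallel layers around the model."""
    depth = 0
    while isinstance(m, nn.DataParallel):
        m = m.module
        depth += 1
    return depth


# Pattern from the training loop: a new wrapper is created on every step
buggy = nn.Linear(4, 1)
for step in range(3):
    buggy = nn.DataParallel(buggy)

# Suggested fix: wrap once, before the epoch loop
fixed = nn.DataParallel(nn.Linear(4, 1))
for step in range(3):
    pass  # forward, loss.backward() and optimizer.step() would go here

print(wrap_depth(buggy), wrap_depth(fixed))  # 3 1
```

On a CPU-only machine nn.DataParallel simply calls the wrapped module, which would explain why the nesting went unnoticed in the CPU run.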

I don't know why this usage gives a device mismatch error, and I think a better error message should be raised. Could you create a GitHub issue, so that we could track and fix it?
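The shrinking batches in the earlier logs are consistent with this nesting: the outer wrapper scatters the batch, and the inner wrapper on replica 0 scatters that share over the devices again. A sketch of the arithmetic, assuming torch.chunk-style splitting at each level (inner_split is a hypothetical helper):

```python
import math


def chunk_sizes(batch, n_devices):
    # torch.chunk rule: ceil(batch / n) rows per chunk, the last one may be smaller
    if batch == 0:
        return []
    size = math.ceil(batch / n_devices)
    return [min(size, batch - start) for start in range(0, batch, size)]


def inner_split(batch, n_devices):
    """Hypothetical model of DataParallel(DataParallel(net)): the outer wrapper
    scatters the batch, then the inner wrapper on replica 0 scatters its share."""
    outer = chunk_sizes(batch, n_devices)
    return outer, chunk_sizes(outer[0], n_devices)


print(inner_split(20, 4))  # ([5, 5, 5, 5], [2, 2, 1])
print(inner_split(5, 4))   # ([2, 2, 1], [1, 1])
```

The inner sizes match the second-step prints in both logs (2, 2, 1 on cuda:0 to cuda:2 for batch size 20; 1, 1 on cuda:0 and cuda:1 for batch size 5), and the doubly nested "Caught RuntimeError in replica ..." traceback points the same way.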


@ptrblck Thanks a lot for this. It works after your suggestion. I will create a GitHub issue later.
I have a further question regarding saving and loading wrapped model.

In this project, I will need to save the trained model and load it for use later.

If I wrap the model like this:

wrapped_contact_net=torch.nn.DataParallel(contact_net)

Am I supposed to save it like this?

try:
    state_dict = wrapped_contact_net.module.state_dict()
except AttributeError:
    state_dict = wrapped_contact_net.state_dict()
torch.save(state_dict, model_path)

If I want to retrain the trained model, do I have to wrap the model before I load it?

wrapped_contact_net=torch.nn.DataParallel(contact_net)
wrapped_contact_net.load_state_dict(torch.load(model_path,map_location=device))

or wrap the model after I load it?

contact_net.load_state_dict(torch.load(model_path,map_location=device))
wrapped_contact_net=torch.nn.DataParallel(contact_net)

Thanks in advance.

I would store the model.module.state_dict() to make it independent from nn.DataParallel.
This would also mean your second approach is correct, i.e. create the model, load the state_dict, wrap it in nn.DataParallel.
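A sketch of that save/load order, assuming a hypothetical build_model() in place of ContactAttention_simple_fix_PE and an in-memory io.BytesIO buffer in place of model_path:

```python
import io

import torch
import torch.nn as nn


def build_model():
    # Hypothetical stand-in for ContactAttention_simple_fix_PE(d=..., L=...)
    return nn.Sequential(nn.Linear(111, 5), nn.ReLU(), nn.Linear(5, 1))


# Training side: wrap once, but save the weights of the inner module
model = build_model()
wrapped = nn.DataParallel(model)
buffer = io.BytesIO()  # stands in for model_path
torch.save(wrapped.module.state_dict(), buffer)

# Loading side: create the plain model, load the weights, then wrap it
buffer.seek(0)
restored = build_model()
restored.load_state_dict(torch.load(buffer, map_location="cpu"))
restored_wrapped = nn.DataParallel(restored)

print(list(restored.state_dict())[:2])  # ['0.weight', '0.bias']
```

Because the saved keys carry no "module." prefix, the same file also loads into an unwrapped model, e.g. for single-device inference.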

@ptrblck Thanks, I am OK with loading now. As for saving, do you mean I should save with the original model name, like:

state_dict = contact_net.module.state_dict()

instead of using the wrapped model:

state_dict = wrapped_contact_net.module.state_dict()

You should use the nn.DataParallel object, so wrapped_contact_net in your case (I just used model as a placeholder, as usually you would just overwrite the model variable).
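Note that the wrapper does not copy the model: wrapped_contact_net.module is the very object contact_net refers to, and a plain module has no .module attribute. A quick check with a stand-in nn.Linear:

```python
import torch.nn as nn

contact_net = nn.Linear(4, 2)  # stand-in for the real model
wrapped_contact_net = nn.DataParallel(contact_net)

print(wrapped_contact_net.module is contact_net)  # True
print(hasattr(contact_net, "module"))             # False
print(list(wrapped_contact_net.state_dict()))     # ['module.weight', 'module.bias']
print(list(contact_net.state_dict()))             # ['weight', 'bias']
```

So contact_net.state_dict() and wrapped_contact_net.module.state_dict() return the same keys and tensors, while contact_net.module.state_dict() would raise an AttributeError.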
