Error message:
Expected tensor for ‘out’ to have the same device as tensor for argument #2 ‘mat1’; but device 1 does not equal 0 (while checki
ng arguments for addmm)
I understand this error has been discussed quite a lot and after reading several posts I had a basic idea of why this occurs on my code. Mainly because I am using a very complex model.
My code works fine on single-GPU mode, after adding torch.nn.DataParallel, I tried to run on a 4-GPU node, the error occurred. Can someone kindly have a look at my model and point out where to modify please?
CUDA Setting:
os.environ["CUDA_VISIBLE_DEVICES"]= '0,1,2,3'
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
Input variable CUDA setting:
PE_batch = get_pe(seq_lens, seq_len).float().to(device)
seq_embedding_batch = torch.Tensor(seq_embeddings.float()).to(device)
state_pad = torch.zeros([matrix_reps_batch.shape[0],seq_len, seq_len]).to(device)
Model instantiation and application:
contact_net = ContactAttention_simple_fix_PE(d=d, L=seq_len, device=device).to(device)
contact_net = torch.nn.DataParallel(contact_net)
output = contact_net(PE_batch,seq_embedding_batch, state_pad)
Models details (problem should be here, I am using a subcalss of a class from nn.Module,so there are two model classes. During my debug, I have added .to(self.device) to every operation in forward() in case I miss any of the layers):
class ContactAttention_simple(nn.Module):
def __init__(self, d,L):
super(ContactAttention_simple, self).__init__()
self.d = d
self.L = L
self.conv1d1= nn.Conv1d(in_channels=4, out_channels=d,
kernel_size=9, padding=8, dilation=2)
self.bn1 = nn.BatchNorm1d(d)
self.conv_test_1 = nn.Conv2d(in_channels=6*d, out_channels=d, kernel_size=1)
self.bn_conv_1 = nn.BatchNorm2d(d)
self.conv_test_2 = nn.Conv2d(in_channels=d, out_channels=d, kernel_size=1)
self.bn_conv_2 = nn.BatchNorm2d(d)
self.conv_test_3 = nn.Conv2d(in_channels=d, out_channels=1, kernel_size=1)
self.position_embedding_1d = nn.Parameter(
torch.randn(1, d, 600)
)
self.encoder_layer = nn.TransformerEncoderLayer(2*d, 2)
self.transformer_encoder = nn.TransformerEncoder(self.encoder_layer, 3)
def forward(self, prior, seq, state):
position_embeds = self.position_embedding_1d.repeat(seq.shape[0],1,1)
seq = seq.permute(0, 2, 1) # 4*L
seq = F.relu(self.bn1(self.conv1d1(seq))) #d*L just for increase the capacity
seq = torch.cat([seq, position_embeds], 1) # 2d*L
seq = self.transformer_encoder(seq.permute(-1, 0, 1))
seq = seq.permute(1, 2, 0)
seq_mat = self.matrix_rep(seq) # 4d*L*L
p_mat = self.matrix_rep(position_embeds) # 2d*L*L
infor = torch.cat([seq_mat, p_mat], 1) # 6d*L*L
contact = F.relu(self.bn_conv_1(self.conv_test_1(infor)))
contact = F.relu(self.bn_conv_2(self.conv_test_2(contact)))
contact = self.conv_test_3(contact)
contact = contact.view(-1, self.L, self.L)
contact = (contact+torch.transpose(contact, -1, -2))/2
return contact.view(-1, self.L, self.L)
def matrix_rep(self, x):
x = x.permute(0, 2, 1) # L*d
L = x.shape[1]
x2 = x
x = x.unsqueeze(1)
x2 = x2.unsqueeze(2)
x = x.repeat(1, L,1,1)
x2 = x2.repeat(1, 1, L,1)
mat = torch.cat([x,x2],-1) # L*L*2d
mat_tril = torch.tril(mat.permute(0, -1, 1, 2)) # 2d*L*L
mat_diag = mat_tril - torch.tril(mat.permute(0, -1, 1, 2), diagonal=-1)
mat = mat_tril + torch.transpose(mat_tril, -2, -1) - mat_diag
return mat
class ContactAttention_simple_fix_PE(ContactAttention_simple):
def __init__(self, d, L, device):
super(ContactAttention_simple_fix_PE, self).__init__(d, L)
self.device=device
self.PE_net = nn.Sequential(
nn.Linear(111,5*d),
nn.ReLU(),
nn.Linear(5*d,5*d),
nn.ReLU(),
nn.Linear(5*d,d))
def forward(self, pe, seq, state):
position_embeds = self.PE_net(pe.view(-1, 111).to(self.device)).view(-1, self.L, self.d).to(self.device) # N*L*111 -> N*L*d
position_embeds = position_embeds.permute(0, 2, 1).to(self.device) # N*d*L
seq = seq.permute(0, 2, 1).to(self.device) # 4*L
seq = F.relu(self.bn1(self.conv1d1(seq))).to(self.device) #d*L just for increase the capacity
seq = torch.cat([seq, position_embeds], 1).to(self.device) # 2d*L
seq = self.transformer_encoder(seq.permute(-1, 0, 1).to(self.device)).to(self.device)
seq = seq.permute(1, 2, 0).to(self.device)
seq_mat = self.matrix_rep(seq).to(self.device) # 4d*L*L
p_mat = self.matrix_rep(position_embeds).to(self.device) # 2d*L*L
infor = torch.cat([seq_mat, p_mat], 1).to(self.device) # 6d*L*L
contact = F.relu(self.bn_conv_1(self.conv_test_1(infor))).to(self.device)
contact = F.relu(self.bn_conv_2(self.conv_test_2(contact))).to(self.device)
contact = self.conv_test_3(contact).to(self.device)
contact = contact.view(-1, self.L, self.L).to(self.device)
contact = ((contact.to(self.device)+torch.transpose(contact, -1, -2).to(self.device))/2).to(self.device)
return contact.view(-1, self.L, self.L).to(self.device)