I think I have found the cause of the problem. Here is my code — it is a runnable demo, and you need to install:
torch==1.7.1+cu110
transformers==4.3.2
import transformers
import torch
from transformers import ElectraModel
from transformers import AdamW
from tqdm import tqdm
class Encoder(torch.nn.Module):
    """Shared ELECTRA-large discriminator that encodes two token-id
    batches with the same weights, returning both outputs."""

    def __init__(self):
        super().__init__()
        self.encoder = ElectraModel.from_pretrained("google/electra-large-discriminator")

    def forward(self, x1, x2):
        # Encode each batch through the single shared encoder.
        return [self.encoder(batch) for batch in (x1, x2)]
class BigModel(torch.nn.Module):
    """Data-parallel ELECTRA encoder (cuda:0-3) feeding classification
    heads manually placed on cuda:4/cuda:5 (pipeline-style model
    parallelism). Built as a minimal repro of a DataParallel
    batch-splitting problem.
    """

    def __init__(self):
        super(BigModel, self).__init__()
        # DataParallel scatters EACH positional argument of Encoder.forward
        # along dim 0 independently across the 4 listed devices.
        self.encoder = torch.nn.DataParallel(Encoder(),device_ids=["cuda:0", "cuda:1", "cuda:2", "cuda:3"] )
        # 1024 matches ELECTRA-large's hidden size; final head outputs 2
        # logits per token (binary classification).
        self.l1 = torch.nn.Linear(1024, 1024)
        self.l2 = torch.nn.Linear(1024, 1024)
        self.l3 = torch.nn.Linear(1024, 2)
        self.loss = torch.nn.CrossEntropyLoss()

    def deploy(self):
        # Place the heads and the loss on dedicated GPUs; must be called
        # after construction and before the first forward pass.
        self.l1 = self.l1.to("cuda:4")
        self.l2 = self.l2.to("cuda:4")
        self.l3 = self.l3.to("cuda:5")
        self.loss = self.loss.to("cuda:5")

    def forward(self, x: torch.LongTensor, x1: torch.LongTensor, y: torch.LongTensor):
        """Encode x and x1, combine the two sequences pairwise, and return
        the mean token-level cross-entropy against y.

        NOTE(review): x and x1 may have different batch sizes (16 vs 6 in
        the demo). DataParallel chunks them independently, and the printed
        shapes below show x's output coming back with 12 rows instead of
        16 — presumably one replica's split is dropped; this is the bug
        being reported.
        """
        res = self.encoder(x, x1)
        # res[i][0] is assumed to be the encoder's last hidden state,
        # shape (batch, seq, 1024) — hop cuda:4 (l1/l2) -> cuda:5 (l3/loss).
        x2 = self.l1(res[0][0].to("cuda:4")).to("cuda:5")
        x3 = self.l2(res[1][0].to("cuda:4")).to("cuda:5")
        print(x2.shape)
        print(x3.shape)
        # Broadcast: (B1,1,seq,1024) + (1,B2,seq,1024) -> mean over dim 1
        # -> (B1,seq,1024) -> logits (B1,seq,2).
        x4 = self.l3((x2.unsqueeze(1) + x3.unsqueeze(0)).mean(1))
        # y is expected to have B1*seq labels; the traceback shows 3072
        # logit rows vs 4096 targets because B1 came back as 12, not 16.
        return self.loss(x4.reshape((-1, 2)), y.reshape((-1,))).mean()
def main():
    """Train the repro model on synthetic data: a token-level binary
    classification task, 10 epochs x 100 random batches."""
    model = BigModel().cuda()
    model.deploy()
    optimizer = AdamW(model.parameters(), lr=3e-5)
    for _epoch in range(10):
        for _step in tqdm(range(100)):
            ids = torch.randint(0, 30000, (16, 256)).cuda()
            # Batch size 6 is deliberately different from 16 — this is
            # the "strange" size that triggers the problem.
            ids1 = torch.randint(0, 30000, (6, 256)).cuda()
            # Labels live on cuda:5, where the loss was deployed.
            y = torch.randint(0, 2, (16, 256)).to("cuda:5")
            model.zero_grad()
            loss = model(ids, ids1, y)
            loss.backward()
            optimizer.step()


main()
After I start this Python program, I get the following error output:
0%| | 0/100 [00:00<?, ?it/s]
torch.Size([12, 256, 1024])
torch.Size([6, 256, 1024])
0%| | 0/100 [00:09<?, ?it/s]
Traceback (most recent call last):
File "demo.py", line 58, in <module>
main()
File "demo.py", line 53, in main
loss = model(ids, ids1, y)
File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py", line 727, in _call_impl
result = self.forward(*input, **kwargs)
File "demo.py", line 36, in forward
return self.loss(x4.reshape((-1, 2)), y.reshape((-1,))).mean()
File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py", line 727, in _call_impl
result = self.forward(*input, **kwargs)
File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/loss.py", line 962, in forward
ignore_index=self.ignore_index, reduction=self.reduction)
File "/usr/local/lib/python3.7/dist-packages/torch/nn/functional.py", line 2468, in cross_entropy
return nll_loss(log_softmax(input, 1), target, weight, None, ignore_index, None, reduction)
File "/usr/local/lib/python3.7/dist-packages/torch/nn/functional.py", line 2262, in nll_loss
.format(input.size(0), target.size(0)))
ValueError: Expected input batch_size (3072) to match target batch_size (4096).
This error happens in the loss function: torch does not check that the batch sizes of all input tensors match. At the same time, one split of the encoder's result is lost.