Using the same code on a single GPU gives a different loss curve:
But using the same code on a single node with multiple GPUs gives random results:
Here is my Trainer class that handles multi-GPU training:
import numpy as np
import torch
import torch.nn.functional as F
from torch.nn.parallel import DistributedDataParallel as DDP

class Trainer:
    def __init__(self, model, train_data, val_data, optimizer, gpu_id, save_every):
        self.gpu_id = gpu_id
        self.model = model.to(gpu_id)
        self.train_data = train_data
        self.val_data = val_data
        self.optimizer = optimizer
        self.save_every = save_every
        # Wrap the model (already on this rank's device) in DDP
        self.model = DDP(self.model, device_ids=[gpu_id], find_unused_parameters=True)

    def _run_batch(self, source, targets):
        self.optimizer.zero_grad()
        output = self.model(source)
        loss = F.cross_entropy(output, targets)
        loss.backward()
        self.optimizer.step()
        return loss

    def _run_epoch(self, epoch):
        b_sz = len(next(iter(self.train_data))[0])
        # print(f"[GPU{self.gpu_id}] Epoch {epoch} | Batchsize: {b_sz} | Steps: {len(self.train_data)}")
        self.train_data.sampler.set_epoch(epoch)  # reshuffle the DistributedSampler for this epoch
        loss_ = []
        for source, targets in self.train_data:
            source = source.to(self.gpu_id)
            targets = targets.to(self.gpu_id)
            loss = self._run_batch(source, targets)
            loss_.append(loss.item())
        return np.mean(loss_)

    def _run_val_epoch(self, epoch):
        b_sz = len(next(iter(self.val_data))[0])
        # print(f"[GPU{self.gpu_id}] Epoch {epoch} | Batchsize: {b_sz} | Steps: {len(self.val_data)}")
        self.val_data.sampler.set_epoch(epoch)
        loss_ = []
        self.model.eval()
        with torch.no_grad():  # validation needs no gradients
            for source, targets in self.val_data:
                source = source.to(self.gpu_id)
                targets = targets.to(self.gpu_id)
                output = self.model(source)
                loss = F.cross_entropy(output, targets)
                loss_.append(loss.item())
        self.model.train()
        return np.mean(loss_)

    def _save_checkpoint(self, epoch):
        ckp = self.model.module.state_dict()  # unwrap DDP before saving
        PATH = "ddp_checkpoint.pt"
        torch.save(ckp, PATH)
        print(f"Epoch {epoch} | Training checkpoint saved at {PATH}")

    def train(self, max_epochs):
        total_loss = {'train_loss': [], 'val_loss': []}
        for epoch in range(max_epochs):
            train_loss = self._run_epoch(epoch)
            val_loss = self._run_val_epoch(epoch)
            total_loss['train_loss'].append(train_loss)
            total_loss['val_loss'].append(val_loss)
            print(f"Epoch: {epoch}")
            print(f"train loss: {train_loss}, val_loss: {val_loss}")
            if self.gpu_id == 0 and epoch % self.save_every == 0:
                self._save_checkpoint(epoch)
        # note: every rank writes/overwrites loss.txt with its own losses
        with open('loss.txt', 'w') as f:
            f.write(str(total_loss))
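For context, here is a minimal launch sketch of the setup this Trainer assumes: one process per GPU, an NCCL process group, and DataLoaders built with DistributedSampler (which the sampler.set_epoch calls above rely on). MyModel, MyDataset, and the hyperparameters here are placeholders, not my actual values:

import os
import torch
import torch.multiprocessing as mp
from torch.distributed import init_process_group, destroy_process_group
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

def ddp_setup(rank, world_size):
    # One process per GPU; rank 0 hosts the rendezvous
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"
    init_process_group(backend="nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)

def main(rank, world_size):
    ddp_setup(rank, world_size)
    model = MyModel()  # placeholder model
    train_set, val_set = MyDataset("train"), MyDataset("val")  # placeholder datasets
    # DistributedSampler gives each rank a disjoint shard; shuffle must stay False
    train_loader = DataLoader(train_set, batch_size=32, shuffle=False,
                              sampler=DistributedSampler(train_set))
    val_loader = DataLoader(val_set, batch_size=32, shuffle=False,
                            sampler=DistributedSampler(val_set))
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    trainer = Trainer(model, train_loader, val_loader, optimizer,
                      gpu_id=rank, save_every=1)
    trainer.train(max_epochs=10)
    destroy_process_group()

if __name__ == "__main__":
    world_size = torch.cuda.device_count()
    mp.spawn(main, args=(world_size,), nprocs=world_size)

mp.spawn starts world_size processes and passes each one its rank as the first argument, which the Trainer then uses as gpu_id.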