Monitoring loss without synchronisation points

Hi,
I am wondering what the best way is to save the loss of every batch inside a training loop without creating a synchronisation point between the GPU and the CPU. So far I have been creating a list on the CPU and appending every batch loss to it as a tensor during the training loop (without calling .item() or .cpu()). But I am afraid that the CPU still has to wait for the GPU to finish the backward pass before it can add an element to the list.

To put things in context, I have a dataloader and I want to use multiprocess data loading (torch.utils.data.DataLoader with num_workers > 0) so that the GPU does not wait for data and runs asynchronously with the CPU.
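
For reference, a common pattern for keeping the GPU fed is to combine worker processes with pinned (page-locked) host memory and non-blocking host-to-device copies. The sketch below is only an illustration under assumptions: the toy `TensorDataset`, the batch size and the helper names `make_loader` and `count_batches` are hypothetical. `pin_memory` and `non_blocking` are real `DataLoader` and `Tensor.to` arguments, but how much they help depends on the pipeline.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hypothetical toy dataset standing in for the real one: 256 samples.
dataset = TensorDataset(torch.randn(256, 8), torch.randint(0, 2, (256,)))


def make_loader(num_workers=2):
    # Worker processes prepare batches while the main process drives the GPU.
    # Pinned host memory is what allows the copies below to be asynchronous.
    return DataLoader(
        dataset,
        batch_size=32,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=torch.cuda.is_available(),
    )


def count_batches(loader):
    n = 0
    for inputs, labels in loader:
        # non_blocking=True can only overlap with GPU work when the source is
        # pinned; on a CPU-only machine .to(device) just returns the tensor.
        inputs = inputs.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)
        n += 1
    return n
```

On platforms that use the spawn start method (Windows, macOS), iterating a loader with `num_workers > 0` should happen under an `if __name__ == "__main__":` guard.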

My code looks like this:

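```python
import torch
from tqdm import tqdm

# model, optim, loss_func and datasetloader are defined elsewhere
device = torch.device("cuda")
model = model.to(device)
batch_losses = []
for idx, (input, label) in enumerate(tqdm(datasetloader)):
    input = input.to(device)
    optim.zero_grad(set_to_none=True)
    prediction = model(input)
    loss = loss_func(input, prediction)
    loss.backward()
    optim.step()
    batch_losses.append(loss.detach())  # keep the loss on the GPU: no .item() / .cpu()
```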

My first question is: does appending an element to a list on the CPU create a synchronisation point with the GPU? I was not able to use torch.cuda.set_sync_debug_mode with my torch version.
If so, would the solution be to create a tensor on the GPU before the loop, concatenate the detached loss tensor to it at every iteration, and transfer it to the CPU after the training loop?
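
A sketch of that idea, assuming the number of batches is known up front (here the length of a hypothetical list of batches). Writing into a preallocated tensor on the device avoids the reallocation that `torch.cat` would perform on every iteration, and the single `.cpu()` after the loop is the only device-to-host transfer. The model, optimizer and data are toy stand-ins.

```python
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Toy stand-ins for the real model, optimizer and loss (assumptions).
model = torch.nn.Linear(4, 1).to(device)
optim = torch.optim.SGD(model.parameters(), lr=0.01)
loss_func = torch.nn.MSELoss()

# Hypothetical data: 5 batches of 16 samples each.
batches = [(torch.randn(16, 4), torch.randn(16, 1)) for _ in range(5)]

# One slot per batch, allocated on the device before the loop.
batch_losses = torch.empty(len(batches), device=device)

for idx, (inputs, targets) in enumerate(batches):
    inputs, targets = inputs.to(device), targets.to(device)
    optim.zero_grad(set_to_none=True)
    loss = loss_func(model(inputs), targets)
    loss.backward()
    optim.step()
    # Device-side copy into the buffer; nothing is read back on the host.
    batch_losses[idx] = loss.detach()

# Single device-to-host transfer after the loop (this one does synchronize).
losses_cpu = batch_losses.cpu()
```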

Also, does using tqdm create a synchronisation point?

Thank you for your help.

Appending a tensor to a Python list doesn't create a synchronization point, since the list only stores a reference to it, but moving a CUDATensor to the CPU (e.g. via .cpu() or .item()) will.

You might want to update PyTorch to be able to use this debug utility.
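
Since `torch.cuda.set_sync_debug_mode` only exists in newer releases, a small helper can guard the call and restore the previous mode afterwards. This is a sketch: the `sync_debug` name is hypothetical, and it assumes `torch.cuda.get_sync_debug_mode` is available wherever the setter is.

```python
import contextlib

import torch


@contextlib.contextmanager
def sync_debug(mode="warn"):
    """Temporarily set the CUDA sync debug mode ("default", "warn" or "error").

    Yields True if the mode was applied, False if this PyTorch build or
    machine cannot use it (older release without the API, or no GPU).
    """
    supported = torch.cuda.is_available() and hasattr(torch.cuda, "set_sync_debug_mode")
    if not supported:
        yield False
        return
    previous = torch.cuda.get_sync_debug_mode()  # returned as an int
    torch.cuda.set_sync_debug_mode(mode)
    try:
        yield True
    finally:
        # Restore whatever mode was active before entering the block.
        torch.cuda.set_sync_debug_mode(previous)
```

Wrapping a training step in `with sync_debug("error"):` would then raise a `RuntimeError` on any synchronizing CUDA call inside it, while silently doing nothing on versions that lack the utility.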

Your current approach should not create a synchronization point, since you are not moving the data to the CPU, as seen here:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.cuda.set_sync_debug_mode(debug_mode="default")

# setup
model = nn.Linear(1, 1).cuda()
x = torch.randn(10, 1, 1, device='cuda')
y = torch.zeros(1, dtype=torch.long, device='cuda')

# raise an error on any synchronizing CUDA operation from here on
torch.cuda.set_sync_debug_mode(debug_mode="error")
losses = []
for i in range(len(x)):
    x_ = x[i]
    out = model(x_)
    loss = F.cross_entropy(out, y)
    losses.append(loss.detach()) # works
    # losses.append(loss.cpu().detach()) # RuntimeError: called a synchronizing CUDA operation
```
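
Once training is done, the list of detached CUDA losses can be turned into a single CPU tensor with one transfer. A minimal sketch on the same toy setup (the `torch.stack` step and the `losses_cpu` / `mean_loss` names are additions for illustration; it falls back to the CPU when no GPU is present):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Same toy setup as in the example above.
model = nn.Linear(1, 1).to(device)
x = torch.randn(10, 1, 1, device=device)
y = torch.zeros(1, dtype=torch.long, device=device)

losses = []
for i in range(len(x)):
    out = model(x[i])
    loss = F.cross_entropy(out, y)
    losses.append(loss.detach())  # stays on the device, no sync

# After the loop: stack the 0-dim losses into one tensor on the device,
# then copy it to the host once. Only this .cpu() call synchronizes.
losses_cpu = torch.stack(losses).cpu()
mean_loss = losses_cpu.mean().item()  # data is already on the CPU here
```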

As for tqdm: I don’t know for sure, but I wouldn’t think so. Once you’ve updated your PyTorch version, you could check it with the debug mode.
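
One common way a progress bar does end up adding syncs is displaying the loss with `loss.item()` on every step. A sketch of throttling that, assuming a hypothetical `log_every` interval and random stand-in losses: the running sum stays on the device, and `.item()` is called only once per interval.

```python
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

log_every = 50  # hypothetical: how often to pay for one host sync
running = torch.zeros((), device=device)  # 0-dim accumulator on the device
logged = []

for step in range(1, 201):
    loss = torch.rand((), device=device)  # stand-in for a real batch loss
    running += loss.detach()  # device-side add, no sync
    if step % log_every == 0:
        # .item() copies to the host and waits for queued GPU work: one sync
        # per log_every steps instead of one per step. With tqdm, this value
        # could be passed to pbar.set_postfix(loss=...).
        logged.append(running.item() / log_every)
        running.zero_()
```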

Thank you very much for these invaluable insights!