How to check if a batch's labels are all zero during training

I am using the code below to load batches during training. However, a batch may contain only zero labels, in which case the gradient goes to zero, so I want to skip such batches. How can I check whether all labels in a batch are zero during training? Thanks

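    for index, batch in enumerate(train_data_loader):
        images, targets = batch
        # Is my code correct? Skip the batch if all of its labels are zero
        # (sum == 0 only means "all zero" when labels are non-negative).
        if torch.sum(targets) == 0:
            continue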
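
A minimal, self-contained sketch of skipping such batches, assuming the labels come back as a tensor from a standard DataLoader. The toy dataset and the helper names labels_all_zero and count_used_batches are hypothetical, and the actual training step is left as a placeholder comment. It uses targets.any() instead of the sum, so negative labels cannot cancel out and hide a non-zero label.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def labels_all_zero(targets: torch.Tensor) -> bool:
    # True when no element of targets is non-zero; works for labels of any sign.
    return not bool(targets.any())


def count_used_batches(loader: DataLoader) -> int:
    # Hypothetical helper: iterates like a training loop and counts kept batches.
    used = 0
    for index, (images, targets) in enumerate(loader):
        if labels_all_zero(targets):
            continue  # skip the batch: it has no non-zero labels
        # ... forward pass, loss.backward(), optimizer.step() would go here ...
        used += 1
    return used


# Toy data (hypothetical): 8 samples, batch size 2.
# Batches are [0, 0], [1, 0], [0, 0], [2, 1]; the first and third are all zero.
toy_images = torch.randn(8, 3, 4, 4)
toy_targets = torch.tensor([0, 0, 1, 0, 0, 0, 2, 1])
train_data_loader = DataLoader(
    TensorDataset(toy_images, toy_targets), batch_size=2, shuffle=False
)
```

With this toy loader, count_used_batches(train_data_loader) keeps only the two batches that contain a non-zero label.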
       

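You should be able to use targets.any(), which is True if any label is non-zero, and skip the batch when not targets.any(). (targets.nonzero().any() is unreliable here: nonzero() returns the indices of the non-zero elements, so it gives False when the only non-zero label sits at index 0.)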
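
A small sketch comparing the candidate checks on hand-picked label tensors (the values are illustrative, not from the thread). It shows that nonzero() yields indices rather than values, and that a sum-based test can be fooled by signed labels.

```python
import torch

targets = torch.tensor([3, 0, 0])

# nonzero() returns the *indices* of non-zero elements, not their values.
idx = targets.nonzero()  # tensor([[0]])
# Calling .any() on those indices asks whether some index is non-zero,
# so it is False here even though targets contains a non-zero label.
via_nonzero = bool(idx.any())  # False

# Checking the labels directly avoids that pitfall.
via_any = bool(targets.any())  # True
via_count = torch.count_nonzero(targets).item() > 0  # True

# torch.sum(targets) == 0 only means "all zero" for non-negative labels.
signed = torch.tensor([-1, 1])
sum_says_zero = bool(torch.sum(signed) == 0)  # True, yet labels are not all zero
any_says_zero = not bool(signed.any())  # False
```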


Thanks. Which is faster: my torch.sum(targets) == 0 check or the one you suggested?
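
One way to answer this for your own setup is to time the checks directly. The sketch below uses the standard timeit module on a hypothetical batch of 256 binary labels; it reports no numbers because they depend on hardware, tensor size, and device. On CUDA tensors, converting any of these results to a Python bool makes the host wait for the GPU, and that synchronization may matter more than which reduction you pick.

```python
import timeit

import torch

# Hypothetical batch of labels; adjust shape/dtype to match your data.
targets = torch.randint(0, 2, (256,))

# Each check returns True when all labels are zero.
checks = {
    "sum == 0": lambda: bool(torch.sum(targets) == 0),
    "not any()": lambda: not bool(targets.any()),
    "count_nonzero == 0": lambda: torch.count_nonzero(targets).item() == 0,
}


def time_checks(number: int = 1000) -> dict:
    # Average seconds per call for each check; results vary by machine.
    return {
        name: timeit.timeit(fn, number=number) / number
        for name, fn in checks.items()
    }


if __name__ == "__main__":
    for name, secs in time_checks().items():
        print(f"{name:>20}: {secs * 1e6:.2f} us")
```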