How to skip the update if part of a mini-batch does not fit a condition?

Hello.
I want to skip the update and go to the next iteration, for example:

for it, (input, label) in enumerate(dataloader):
    optimizer.zero_grad()
    output = model(input)
    loss = myloss(output, label)
    if loss.item() == 0 or not torch.isfinite(loss):
        loss.backward()
        continue

    loss.backward()
    optimizer.step()

That is what I want to do.

First, once I get a loss of ‘nan’, I keep getting ‘nan’ from then on.
I think the weights got updated based on a ‘nan’ gradient and it continues from there.
How do I solve this?

Second, I make the loss 0 on purpose so the weights are not updated (with a zero gradient the weights shouldn’t change), because some samples in the mini-batch don’t fit my condition.
For example, I have a mini-batch of 8:

class myloss(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # blah, blah...

    def forward(self, x, A_list):
        # A_list = [None, 1, None, 1, 1, 1, 1, 1]  # for example
        batch_loss = torch.tensor(0.0)
        for a in A_list:
            if a is not None:
                loss = criterion(x)  # criterion: the real loss function, defined elsewhere
                batch_loss += loss
            else:
                loss = torch.log(x / x)  # make the loss 0 on purpose... is this the right way?
                batch_loss += loss
        return batch_loss

I don’t know whether making the loss 0 on purpose is the right approach.
Does anyone have a good idea for this?

Thanks in advance.

Calling loss.backward() if the loss is not finite (and thus contains invalid values) wouldn’t make sense, but since optimizer.step() wasn’t called, the parameters should not be invalid.
You could verify this by checking all parameters of the model once the update step is skipped.
However, depending on where the invalid activations were created in the model, other state might already be corrupted: e.g. the running stats of batchnorm layers might have been updated with these NaN/Inf values, which would then break your model in the validation run. You should therefore try to avoid creating invalid activations at all.
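For example, a quick parameter check after a skipped step could look like this (a minimal sketch, assuming model is the nn.Module from your training loop):

# sanity check after a skipped update: no parameter should contain NaN/Inf
for name, param in model.named_parameters():
    if not torch.isfinite(param).all():
        print(f"non-finite values found in parameter {name}")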

I would not recommend this approach, as even if the gradients are zero, the optimizer could still update the parameters if it holds internal states from previous updates (such as Adam).
If you want to skip a batch, skip it explicitly instead of manipulating the gradients.
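As a small standalone sketch of why zeroed gradients are not enough: after one real update Adam's momentum buffers are non-zero, so a later step with an all-zero gradient still moves the parameter.

import torch

param = torch.nn.Parameter(torch.ones(1))
optimizer = torch.optim.Adam([param], lr=0.1)

# first step with a real gradient fills Adam's running averages
param.grad = torch.ones(1)
optimizer.step()

# second step with a zero gradient still moves the parameter,
# because the internal momentum buffers are non-zero
before = param.detach().clone()
param.grad = torch.zeros(1)
optimizer.step()
print((param - before).abs().item())  # prints a value > 0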

@ptrblck
Thanks for your reply.
To summarize: 1. Verify all the model’s parameters after the update step is skipped.
2. The optimizer could still update the parameters using internal states from previous updates, so making the gradient 0 is not appropriate.
So…

If you want to skip a batch, skip it explicitly instead of manipulating the gradients.

I am wondering how to skip it explicitly. Could you give me an example to understand it in detail?
Thank you!

Maybe just use a condition on the backward and step call:

if loss.item() != 0 and torch.isfinite(loss):
    loss.backward()
    optimizer.step()
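
Plugged back into your training loop (a sketch, assuming the same model, myloss and dataloader names from your first post), the explicit skip would look roughly like this:

for it, (input, label) in enumerate(dataloader):
    optimizer.zero_grad()
    output = model(input)
    loss = myloss(output, label)

    # skip this batch entirely: no backward(), no optimizer.step()
    if loss.item() == 0 or not torch.isfinite(loss):
        continue

    loss.backward()
    optimizer.step()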

@ptrblck
Thanks again for your reply.
But I am still confused, so let me give you another example.
What if A_list = [None, None, None, None, None, None, None, None] instead of [None, 1, None, 1, 1, 1, 1, 1]? (A_list is a supplementary means to filter out data that does not fit my idea.)
According to your answer, if I remove the “loss 0 on purpose” part, myloss would return only the zero initialization, as below.
I think it couldn’t do back-propagation because the loss is not connected to anything.
In this case, is it possible to skip the update explicitly?
What did I misunderstand?

class myloss(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # blah, blah...

    def forward(self, x, A_list):
        # A_list = [None, None, None, None, None, None, None, None]  # for example
        batch_loss = torch.tensor(0.0)
        for a in A_list:
            if a is not None:
                loss = criterion(x)
                batch_loss += loss
        return batch_loss

And how should I initialize batch_loss? torch.tensor(0.0) may be wrong.

@ptrblck
I’m so sorry to bother you… Once again, could you suggest a solution for this?
Thank you.

I didn’t understand exactly.
If you’d like to discuss it in a quick chat, send me a message.

Yes, it should be possible to skip the backward and optimizer.step() calls by e.g. checking if the returned batch_loss tensor has a valid .grad_fn attribute or if it’s set to None.
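
A minimal sketch of that check (assuming loss_fn = myloss() is an instance of your loss module, and output and A_list come from the current batch):

batch_loss = loss_fn(output, A_list)

# if every entry in A_list was None, batch_loss is still the plain
# torch.tensor(0.0) initialization and its .grad_fn is None
if batch_loss.grad_fn is not None and torch.isfinite(batch_loss):
    batch_loss.backward()
    optimizer.step()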

@ptrblck
I really appreciate it. Thecho7 also let me know how to use a None check to skip the update. Thank you both.