Getting nan values after first batch

Hi,
I’m trying to reproduce the paper:
Counting Out Time: Class Agnostic Video Repetition Counting in the Wild
https://openaccess.thecvf.com/content_CVPR_2020/papers/Dwibedi_Counting_Out_Time_Class_Agnostic_Video_Repetition_Counting_in_the_CVPR_2020_paper.pdf

After the first batch I’m getting nan as outputs, my loss.grad=None.
My inputs are normalized.
I’m adding a snippet of the code.

class Encoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.cnn = nn.Sequential(*list(models.resnet50(pretrained=True).children())[:-3])
        self.temporal_context = nn.Sequential(nn.Conv3d(in_channels=1024,
                                out_channels=512,kernel_size=3,padding=[1,0,0]), nn.ReLU())
   

    def forward(self,x):
        b, f, c, h, w = x.size()
        features = self.cnn(x.view(b*f, c, h, w))
        temp_features = self.temporal_context(features.unsqueeze(0).view(b,features.shape[1], f,features.shape[2],features.shape[3]))
        out = nn.AvgPool2d(5)(temp_features.view(b*f,512, temp_features.shape[-1], temp_features.shape[-1])).view(b, f, -1)
        return out

class PeriodPredictor(nn.Module):

    def __init__(self):
        super().__init__()
        self.cnn = nn.Sequential(nn.Conv2d(in_channels=1,out_channels=32,kernel_size=3), nn.ReLU())
        self.projection = nn.Linear(1922,512)
        self.transformer = nn.Transformer(nhead=4,dim_feedforward=128)

    def forward(self, x):
        out = self.cnn(x)
        out = out.view(x.size(0), 64, -1)
        out = self.projection(out)
        out = self.transformer(out, out)
        return out

class Classifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.l_network = nn.Sequential(nn.Linear(512,512),nn.ReLU(),nn.Linear(512,32))
        self.p_network = nn.Sequential(nn.Linear(512,512),nn.ReLU(),nn.Linear(512,1))

    def forward(self, x):
        b, f, e = x.size()
        l = self.l_network(x.view(b*f,-1))
        p = self.p_network(x.view(b*f,-1))
        return l.view(b,-1), p.view(b,-1)

class RepNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = Encoder()
        self.period_predictor = PeriodPredictor()
        self.classifier = Classifier()
        self.optimizer = torch.optim.SGD(params=self.parameters(),  lr=1e-3, momentum=0.9, weight_decay=5e-4)

    def forward(self, x):
        features = self.encoder(x)
        sim_mat = -torch.matrix_power(torch.cdist(features, features), 2).unsqueeze(1)
        period = self.period_predictor(sim_mat)
        l, p = self.classifier(period)
        return l, p

    def train_epoch(self, loader, epoch, device):
        self.train()
        print(
            "\n" + " " * 10 + "****** Epoch {epoch} ******\n"
            .format(epoch=epoch)
        )

        training_losses = []
        mae = deque(maxlen=30)
        self.optimizer.zero_grad()
        with tqdm(total=len(loader), ncols=80) as pb:


            for batch_idx, d in enumerate(loader):
                frames, l, p = d
                frames, l, p = frames.to(device), l.to(device), p.to(device)
                frames.requires_grad = True
                self.optimizer.zero_grad()
                l_logits, p_logits = self.forward(frames)
                loss = torch.nn.BCELoss()(torch.nn.Sigmoid()(p_logits.view(-1)),p.float().view(-1))
                loss += torch.nn.CrossEntropyLoss()(l_logits.view(frames.shape[0]*64,32), l.view(-1))
                training_losses.append(loss.data.cpu().numpy())
                loss.backward()
                clip_gradient(self.optimizer, 0.1)
                self.optimizer.step()
                counts_t = []
                counts_p = []
                p_preds = torch.where(torch.nn.Sigmoid()(p_logits) > 0.5, torch.tensor(1).cuda(), torch.tensor(0).cuda())
                l_preds = l_logits.view(frames.shape[0],64,32).argmax(2)
                for l_i, p_i, ll_i, pl_i in zip(l, p, l_preds, p_preds):
                    reps = torch.where(l_i>0)[0]
                    counts_t.append( torch.sum(torch.div(p_i[reps].float(),l_i[reps].float())).data.cpu().numpy())
                    counts_p.append( torch.sum(torch.div(pl_i[reps].float(),ll_i[reps].float())).data.cpu().numpy())
                try:
                    mae_i = mean_absolute_error(np.array(counts_t), np.array(counts_p))
                    mae.append(mae_i)
                    pb.update()
                    pb.set_description(
                        f"Loss: {loss.item()}, MAE: {np.mean(mae)}")

                except Exception as e:
                    print(e)
                    continue


    def validate(self, loader):
        self.eval()
        with tqdm(total=len(loader), ncols=80) as pb:
            for batch_idx, frames, l, p in enumerate(loader):
                l_logits, p_logits = self.forward(frames)
                loss = torch.nn.CrossEntropyLoss()(l_logits, l) + torch.nn.CrossEntropyLoss()(p_logits, p)
                pb.update()
                pb.set_description(
                    "Loss: {:.4f}".format(
                        loss.item()))```

I assume “after the first batch” means that the first output and loss tensors are valid, while the second iteration produces a NaN output?
If that’s the case, could you check all gradients in the model using:

for name, param in model.named_parameters():
    print(name, torch.isfinite(param.grad).all())

after the first backward call and check, if some values are NaNs?
Alternatively, you could also set torch.autograd.set_detect_anomaly(True) at the beginning of the script, which should give you a stack trace pointing to the method, which created the NaNs in the backward pass.

1 Like

This is the code i run:

    repnet = RepNet()
    p = torch.zeros((5,64))
    l = torch.zeros((5,64))
    for i in range(2):
        dummy = torch.randn((5, 64, 3, 112, 112))
        repnet.optimizer.zero_grad()
        l_logits, p_logits = repnet(dummy)
        loss = torch.nn.BCELoss()(torch.nn.Sigmoid()(p_logits.view(-1)), p.float().view(-1))
        loss += torch.nn.CrossEntropyLoss()(l_logits.view(dummy.shape[0] * 64, 32), l.view(-1).long())
        loss.backward()
        clip_gradient(repnet.optimizer, 0.1)
        repnet.optimizer.step()
        for name, param in repnet.named_parameters():
            print(name, torch.isfinite(param.grad).all())

    

This is the output:

encoder.temporal_context.0.weight tensor(False)
encoder.temporal_context.0.bias tensor(False)
period_predictor.cnn.0.weight tensor(True)
period_predictor.cnn.0.bias tensor(True)
period_predictor.projection.weight tensor(True)
period_predictor.projection.bias tensor(True)
period_predictor.transformer.encoder.layers.0.self_attn.in_proj_weight tensor(True)
period_predictor.transformer.encoder.layers.0.self_attn.in_proj_bias tensor(True)
period_predictor.transformer.encoder.layers.0.self_attn.out_proj.weight tensor(True)
period_predictor.transformer.encoder.layers.0.self_attn.out_proj.bias tensor(True)
period_predictor.transformer.encoder.layers.0.linear1.weight tensor(True)
period_predictor.transformer.encoder.layers.0.linear1.bias tensor(True)
period_predictor.transformer.encoder.layers.0.linear2.weight tensor(True)
period_predictor.transformer.encoder.layers.0.linear2.bias tensor(True)
period_predictor.transformer.encoder.layers.0.norm1.weight tensor(True)
period_predictor.transformer.encoder.layers.0.norm1.bias tensor(True)
period_predictor.transformer.encoder.layers.0.norm2.weight tensor(True)
period_predictor.transformer.encoder.layers.0.norm2.bias tensor(True)
period_predictor.transformer.encoder.layers.1.self_attn.in_proj_weight tensor(True)
period_predictor.transformer.encoder.layers.1.self_attn.in_proj_bias tensor(True)
period_predictor.transformer.encoder.layers.1.self_attn.out_proj.weight tensor(True)
period_predictor.transformer.encoder.layers.1.self_attn.out_proj.bias tensor(True)
period_predictor.transformer.encoder.layers.1.linear1.weight tensor(True)
period_predictor.transformer.encoder.layers.1.linear1.bias tensor(True)
period_predictor.transformer.encoder.layers.1.linear2.weight tensor(True)
period_predictor.transformer.encoder.layers.1.linear2.bias tensor(True)
period_predictor.transformer.encoder.layers.1.norm1.weight tensor(True)
period_predictor.transformer.encoder.layers.1.norm1.bias tensor(True)
period_predictor.transformer.encoder.layers.1.norm2.weight tensor(True)
period_predictor.transformer.encoder.layers.1.norm2.bias tensor(True)
period_predictor.transformer.encoder.layers.2.self_attn.in_proj_weight tensor(True)
period_predictor.transformer.encoder.layers.2.self_attn.in_proj_bias tensor(True)
period_predictor.transformer.encoder.layers.2.self_attn.out_proj.weight tensor(True)
period_predictor.transformer.encoder.layers.2.self_attn.out_proj.bias tensor(True)
period_predictor.transformer.encoder.layers.2.linear1.weight tensor(True)
period_predictor.transformer.encoder.layers.2.linear1.bias tensor(True)
period_predictor.transformer.encoder.layers.2.linear2.weight tensor(True)
period_predictor.transformer.encoder.layers.2.linear2.bias tensor(True)
period_predictor.transformer.encoder.layers.2.norm1.weight tensor(True)
period_predictor.transformer.encoder.layers.2.norm1.bias tensor(True)
period_predictor.transformer.encoder.layers.2.norm2.weight tensor(True)
period_predictor.transformer.encoder.layers.2.norm2.bias tensor(True)
period_predictor.transformer.encoder.layers.3.self_attn.in_proj_weight tensor(True)
period_predictor.transformer.encoder.layers.3.self_attn.in_proj_bias tensor(True)
period_predictor.transformer.encoder.layers.3.self_attn.out_proj.weight tensor(True)
period_predictor.transformer.encoder.layers.3.self_attn.out_proj.bias tensor(True)
period_predictor.transformer.encoder.layers.3.linear1.weight tensor(True)
period_predictor.transformer.encoder.layers.3.linear1.bias tensor(True)
period_predictor.transformer.encoder.layers.3.linear2.weight tensor(True)
period_predictor.transformer.encoder.layers.3.linear2.bias tensor(True)
period_predictor.transformer.encoder.layers.3.norm1.weight tensor(True)
period_predictor.transformer.encoder.layers.3.norm1.bias tensor(True)
period_predictor.transformer.encoder.layers.3.norm2.weight tensor(True)
period_predictor.transformer.encoder.layers.3.norm2.bias tensor(True)
period_predictor.transformer.encoder.layers.4.self_attn.in_proj_weight tensor(True)
period_predictor.transformer.encoder.layers.4.self_attn.in_proj_bias tensor(True)
period_predictor.transformer.encoder.layers.4.self_attn.out_proj.weight tensor(True)
period_predictor.transformer.encoder.layers.4.self_attn.out_proj.bias tensor(True)
period_predictor.transformer.encoder.layers.4.linear1.weight tensor(True)
period_predictor.transformer.encoder.layers.4.linear1.bias tensor(True)
period_predictor.transformer.encoder.layers.4.linear2.weight tensor(True)
period_predictor.transformer.encoder.layers.4.linear2.bias tensor(True)
period_predictor.transformer.encoder.layers.4.norm1.weight tensor(True)
period_predictor.transformer.encoder.layers.4.norm1.bias tensor(True)
period_predictor.transformer.encoder.layers.4.norm2.weight tensor(True)
period_predictor.transformer.encoder.layers.4.norm2.bias tensor(True)
period_predictor.transformer.encoder.layers.5.self_attn.in_proj_weight tensor(True)
period_predictor.transformer.encoder.layers.5.self_attn.in_proj_bias tensor(True)
period_predictor.transformer.encoder.layers.5.self_attn.out_proj.weight tensor(True)
period_predictor.transformer.encoder.layers.5.self_attn.out_proj.bias tensor(True)
period_predictor.transformer.encoder.layers.5.linear1.weight tensor(True)
period_predictor.transformer.encoder.layers.5.linear1.bias tensor(True)
period_predictor.transformer.encoder.layers.5.linear2.weight tensor(True)
period_predictor.transformer.encoder.layers.5.linear2.bias tensor(True)
period_predictor.transformer.encoder.layers.5.norm1.weight tensor(True)
period_predictor.transformer.encoder.layers.5.norm1.bias tensor(True)
period_predictor.transformer.encoder.layers.5.norm2.weight tensor(True)
period_predictor.transformer.encoder.layers.5.norm2.bias tensor(True)
period_predictor.transformer.encoder.norm.weight tensor(True)
period_predictor.transformer.encoder.norm.bias tensor(True)
period_predictor.transformer.decoder.layers.0.self_attn.in_proj_weight tensor(True)
period_predictor.transformer.decoder.layers.0.self_attn.in_proj_bias tensor(True)
period_predictor.transformer.decoder.layers.0.self_attn.out_proj.weight tensor(True)
period_predictor.transformer.decoder.layers.0.self_attn.out_proj.bias tensor(True)
period_predictor.transformer.decoder.layers.0.multihead_attn.in_proj_weight tensor(True)
period_predictor.transformer.decoder.layers.0.multihead_attn.in_proj_bias tensor(True)
period_predictor.transformer.decoder.layers.0.multihead_attn.out_proj.weight tensor(True)
period_predictor.transformer.decoder.layers.0.multihead_attn.out_proj.bias tensor(True)
period_predictor.transformer.decoder.layers.0.linear1.weight tensor(True)
period_predictor.transformer.decoder.layers.0.linear1.bias tensor(True)
period_predictor.transformer.decoder.layers.0.linear2.weight tensor(True)
period_predictor.transformer.decoder.layers.0.linear2.bias tensor(True)
period_predictor.transformer.decoder.layers.0.norm1.weight tensor(True)
period_predictor.transformer.decoder.layers.0.norm1.bias tensor(True)
period_predictor.transformer.decoder.layers.0.norm2.weight tensor(True)
period_predictor.transformer.decoder.layers.0.norm2.bias tensor(True)
period_predictor.transformer.decoder.layers.0.norm3.weight tensor(True)
period_predictor.transformer.decoder.layers.0.norm3.bias tensor(True)
period_predictor.transformer.decoder.layers.1.self_attn.in_proj_weight tensor(True)
period_predictor.transformer.decoder.layers.1.self_attn.in_proj_bias tensor(True)
period_predictor.transformer.decoder.layers.1.self_attn.out_proj.weight tensor(True)
period_predictor.transformer.decoder.layers.1.self_attn.out_proj.bias tensor(True)
period_predictor.transformer.decoder.layers.1.multihead_attn.in_proj_weight tensor(True)
period_predictor.transformer.decoder.layers.1.multihead_attn.in_proj_bias tensor(True)
period_predictor.transformer.decoder.layers.1.multihead_attn.out_proj.weight tensor(True)
period_predictor.transformer.decoder.layers.1.multihead_attn.out_proj.bias tensor(True)
period_predictor.transformer.decoder.layers.1.linear1.weight tensor(True)
period_predictor.transformer.decoder.layers.1.linear1.bias tensor(True)
period_predictor.transformer.decoder.layers.1.linear2.weight tensor(True)
period_predictor.transformer.decoder.layers.1.linear2.bias tensor(True)
period_predictor.transformer.decoder.layers.1.norm1.weight tensor(True)
period_predictor.transformer.decoder.layers.1.norm1.bias tensor(True)
period_predictor.transformer.decoder.layers.1.norm2.weight tensor(True)
period_predictor.transformer.decoder.layers.1.norm2.bias tensor(True)
period_predictor.transformer.decoder.layers.1.norm3.weight tensor(True)
period_predictor.transformer.decoder.layers.1.norm3.bias tensor(True)
period_predictor.transformer.decoder.layers.2.self_attn.in_proj_weight tensor(True)
period_predictor.transformer.decoder.layers.2.self_attn.in_proj_bias tensor(True)
period_predictor.transformer.decoder.layers.2.self_attn.out_proj.weight tensor(True)
period_predictor.transformer.decoder.layers.2.self_attn.out_proj.bias tensor(True)
period_predictor.transformer.decoder.layers.2.multihead_attn.in_proj_weight tensor(True)
period_predictor.transformer.decoder.layers.2.multihead_attn.in_proj_bias tensor(True)
period_predictor.transformer.decoder.layers.2.multihead_attn.out_proj.weight tensor(True)
period_predictor.transformer.decoder.layers.2.multihead_attn.out_proj.bias tensor(True)
period_predictor.transformer.decoder.layers.2.linear1.weight tensor(True)
period_predictor.transformer.decoder.layers.2.linear1.bias tensor(True)
period_predictor.transformer.decoder.layers.2.linear2.weight tensor(True)
period_predictor.transformer.decoder.layers.2.linear2.bias tensor(True)
period_predictor.transformer.decoder.layers.2.norm1.weight tensor(True)
period_predictor.transformer.decoder.layers.2.norm1.bias tensor(True)
period_predictor.transformer.decoder.layers.2.norm2.weight tensor(True)
period_predictor.transformer.decoder.layers.2.norm2.bias tensor(True)
period_predictor.transformer.decoder.layers.2.norm3.weight tensor(True)
period_predictor.transformer.decoder.layers.2.norm3.bias tensor(True)
period_predictor.transformer.decoder.layers.3.self_attn.in_proj_weight tensor(True)
period_predictor.transformer.decoder.layers.3.self_attn.in_proj_bias tensor(True)
period_predictor.transformer.decoder.layers.3.self_attn.out_proj.weight tensor(True)
period_predictor.transformer.decoder.layers.3.self_attn.out_proj.bias tensor(True)
period_predictor.transformer.decoder.layers.3.multihead_attn.in_proj_weight tensor(True)
period_predictor.transformer.decoder.layers.3.multihead_attn.in_proj_bias tensor(True)
period_predictor.transformer.decoder.layers.3.multihead_attn.out_proj.weight tensor(True)
period_predictor.transformer.decoder.layers.3.multihead_attn.out_proj.bias tensor(True)
period_predictor.transformer.decoder.layers.3.linear1.weight tensor(True)
period_predictor.transformer.decoder.layers.3.linear1.bias tensor(True)
period_predictor.transformer.decoder.layers.3.linear2.weight tensor(True)
period_predictor.transformer.decoder.layers.3.linear2.bias tensor(True)
period_predictor.transformer.decoder.layers.3.norm1.weight tensor(True)
period_predictor.transformer.decoder.layers.3.norm1.bias tensor(True)
period_predictor.transformer.decoder.layers.3.norm2.weight tensor(True)
period_predictor.transformer.decoder.layers.3.norm2.bias tensor(True)
period_predictor.transformer.decoder.layers.3.norm3.weight tensor(True)
period_predictor.transformer.decoder.layers.3.norm3.bias tensor(True)
period_predictor.transformer.decoder.layers.4.self_attn.in_proj_weight tensor(True)
period_predictor.transformer.decoder.layers.4.self_attn.in_proj_bias tensor(True)
period_predictor.transformer.decoder.layers.4.self_attn.out_proj.weight tensor(True)
period_predictor.transformer.decoder.layers.4.self_attn.out_proj.bias tensor(True)
period_predictor.transformer.decoder.layers.4.multihead_attn.in_proj_weight tensor(True)
period_predictor.transformer.decoder.layers.4.multihead_attn.in_proj_bias tensor(True)
period_predictor.transformer.decoder.layers.4.multihead_attn.out_proj.weight tensor(True)
period_predictor.transformer.decoder.layers.4.multihead_attn.out_proj.bias tensor(True)
period_predictor.transformer.decoder.layers.4.linear1.weight tensor(True)
period_predictor.transformer.decoder.layers.4.linear1.bias tensor(True)
period_predictor.transformer.decoder.layers.4.linear2.weight tensor(True)
period_predictor.transformer.decoder.layers.4.linear2.bias tensor(True)
period_predictor.transformer.decoder.layers.4.norm1.weight tensor(True)
period_predictor.transformer.decoder.layers.4.norm1.bias tensor(True)
period_predictor.transformer.decoder.layers.4.norm2.weight tensor(True)
period_predictor.transformer.decoder.layers.4.norm2.bias tensor(True)
period_predictor.transformer.decoder.layers.4.norm3.weight tensor(True)
period_predictor.transformer.decoder.layers.4.norm3.bias tensor(True)
period_predictor.transformer.decoder.layers.5.self_attn.in_proj_weight tensor(True)
period_predictor.transformer.decoder.layers.5.self_attn.in_proj_bias tensor(True)
period_predictor.transformer.decoder.layers.5.self_attn.out_proj.weight tensor(True)
period_predictor.transformer.decoder.layers.5.self_attn.out_proj.bias tensor(True)
period_predictor.transformer.decoder.layers.5.multihead_attn.in_proj_weight tensor(True)
period_predictor.transformer.decoder.layers.5.multihead_attn.in_proj_bias tensor(True)
period_predictor.transformer.decoder.layers.5.multihead_attn.out_proj.weight tensor(True)
period_predictor.transformer.decoder.layers.5.multihead_attn.out_proj.bias tensor(True)
period_predictor.transformer.decoder.layers.5.linear1.weight tensor(True)
period_predictor.transformer.decoder.layers.5.linear1.bias tensor(True)
period_predictor.transformer.decoder.layers.5.linear2.weight tensor(True)
period_predictor.transformer.decoder.layers.5.linear2.bias tensor(True)
period_predictor.transformer.decoder.layers.5.norm1.weight tensor(True)
period_predictor.transformer.decoder.layers.5.norm1.bias tensor(True)
period_predictor.transformer.decoder.layers.5.norm2.weight tensor(True)
period_predictor.transformer.decoder.layers.5.norm2.bias tensor(True)
period_predictor.transformer.decoder.layers.5.norm3.weight tensor(True)
period_predictor.transformer.decoder.layers.5.norm3.bias tensor(True)
period_predictor.transformer.decoder.norm.weight tensor(True)
period_predictor.transformer.decoder.norm.bias tensor(True)
classifier.l_network.0.weight tensor(True)
classifier.l_network.0.bias tensor(True)
classifier.l_network.2.weight tensor(True)
classifier.l_network.2.bias tensor(True)
classifier.p_network.0.weight tensor(True)
classifier.p_network.0.bias tensor(True)
classifier.p_network.2.weight tensor(True)
classifier.p_network.2.bias tensor(True)

I’m wondering if the bottleneck I calculate in here:

        sim_mat = -torch.matrix_power(torch.cdist(features, features), 2).unsqueeze(1)

may cause this.

@ptrblck any idea on how to define this matrix bottleneck correctly? without it I’m not getting nan values.

I don’t know, why it would result in invalid values.
Could you check sim_mat for NaNs and, if found, store the inputs to this operation, so that we could have a look?