Exploding gradients in validation

Hi everyone. I’m trying to implement a video classification scheme, and everything seems fine so far except one thing: exploding gradients in the validation loop. I know it sounds strange because there aren’t supposed to be any gradients in the validation process, and that’s exactly what I don’t get. I’ve made sure to switch to eval() mode and use torch.no_grad(), yet exploding gradients (with NaN outputs) still happen ONLY when there is a validation loop. I tried commenting out the validation code and the training code ran smoothly, so I figured something must be wrong with the validation code, but I can’t put my finger on it. I’d really appreciate some help pointing me in the right direction.

My code:

    # Keep the loss history across epochs so the returned lists have one entry per epoch
    training_losses = []
    validating_losses = []
    for epoch in range(params.getint('num_epochs')):
        print('Starting epoch %i:' % (epoch + 1))
        print('*********Training*********')
        training_loss = 0
        training_progress = tqdm(enumerate(train_loader))
        artnet.train()
        for batch_index, (frames, label) in training_progress:
            training_progress.set_description('Batch no. %i: ' % batch_index)
            frames = frames.to(device)
            label = label.to(device)

            optimizer.zero_grad()
            output = artnet(frames)
            loss = criterion(output, label)
            training_loss += loss.item()

            loss.backward()
            optimizer.step()

        avg_loss = training_loss / len(train_loader)
        training_losses.append(avg_loss)
        print(f'Training loss: {avg_loss}')

        print('*********Validating*********')
        validating_loss = 0
        validating_progress = tqdm(enumerate(validation_loader))
        artnet.eval()
        with torch.no_grad():
            for batch_index, (frames, label) in validating_progress:
                validating_progress.set_description('Batch no. %i: ' % batch_index)
                frames = frames.to(device)
                label = label.to(device)

                output = artnet(frames)
                loss = criterion(output, label)

                validating_loss += loss.item()

        avg_loss = validating_loss / len(validation_loader)
        validating_losses.append(avg_loss)
        print(f'Validating loss: {avg_loss}')
        print('=============================================')
        print('Epoch %i complete' % (epoch + 1))

        # Save a checkpoint every `ckpt` epochs
        if (epoch + 1) % params.getint('ckpt') == 0:
            print('Saving checkpoint...')
            torch.save(artnet.state_dict(), os.path.join(params['ckpt_path'], 'arnet_%i' % (epoch + 1)))

        # Update LR
        scheduler.step()
    print('Training complete, saving final model...')
    torch.save(artnet.state_dict(), os.path.join(params['ckpt_path'], 'arnet_final'))
    return training_losses, validating_losses

Did you measure the individual gradient magnitudes of the loss w.r.t. the network parameters? I don’t see in the code where you measured the gradients that you say are exploding in the validation phase. From your description, training seems fine: the loss is computed properly on the training set and the gradients backpropagate properly, i.e. the network is learning and the loss is decreasing. But during validation your network generates NaN output. By “output”, do you mean the loss that you are printing, or the actual output of your network, which could for example be a softmax output?
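
If you haven’t measured them yet, something along these lines could help. It is only a rough sketch reusing the artnet, loss, and optimizer names from your snippet, and the 1e3 threshold is arbitrary:

    # Sketch: print per-parameter gradient norms right after loss.backward()
    import torch

    loss.backward()

    total_sq = 0.0
    for name, param in artnet.named_parameters():
        if param.grad is not None:
            g = param.grad.detach()
            g_norm = g.norm(2).item()
            total_sq += g_norm ** 2
            # Flag NaN/inf gradients or anything unusually large (threshold is arbitrary)
            if not torch.isfinite(g).all() or g_norm > 1e3:
                print(f'{name}: grad norm = {g_norm}')
    print(f'total grad norm = {total_sq ** 0.5}')

    optimizer.step()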

Apparently, I was mistaken; the problem is something else entirely, although I’m not sure what. There is a square operation in the network that causes the output to go to infinity. The code looks like this:

    def forward(self, x):
        # First branch: conv -> batch norm -> square -> cross-channel conv -> batch norm -> ReLU
        x_rel = self.conv1(x)
        x_rel = self.bn1(x_rel)
        x_rel = x_rel ** 2          # the square operation that seems to cause the blow-up
        x_rel = self.cc(x_rel)
        x_rel = self.bn2(x_rel)
        x_rel = F.relu(x_rel)
        # Second branch: conv -> batch norm -> ReLU
        x_app = self.conv2(x)
        x_app = self.bn3(x_app)
        x_app = F.relu(x_app)
        # Concatenate both branches along the channel dimension and reduce
        out = torch.cat((x_rel, x_app), dim=1)
        out = self.reduction(out)
        return out

If I comment the square operation out, no problem occurs, but since I’m trying to implement the network proposed in a paper, removing it would make my implementation unfaithful. If I leave the square operation in and comment the validation loop out, nothing bad happens either. But if I keep both the square operation and the validation loop, the output of the validation loop goes to infinity, and every training loop after that also produces infinite output.
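
For reference, here is a rough sketch of the kind of forward-hook check that can narrow this down by printing the name of any layer whose output contains inf or NaN. The make_hook helper is just illustrative, and frames stands for one batch from the loader:

    # Sketch: forward hooks that report layers whose output contains inf/NaN
    import torch

    def make_hook(name):
        def hook(module, inputs, output):
            if isinstance(output, torch.Tensor) and not torch.isfinite(output).all():
                print(f'inf/NaN in output of layer: {name}')
        return hook

    # Register a hook on every named submodule of the model
    hooks = [m.register_forward_hook(make_hook(n)) for n, m in artnet.named_modules() if n]

    artnet.eval()
    with torch.no_grad():
        _ = artnet(frames)   # frames: one batch from the validation loader

    for h in hooks:
        h.remove()

The first layer name printed is where the values start blowing up, which would show whether it is the square itself or one of the batch norm layers around it.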

Ideally, the validation loop should not affect the model at all, because you only do a forward pass there. Make sure you are using model.eval() and the torch.no_grad() context manager when evaluating.

You’re right, I made a mistake: it doesn’t affect the model’s weights as I thought. However, it does affect the training loop for reasons I’ve yet to identify. I do use model.eval() and torch.no_grad(), and I know which line of code is causing the problem, but I don’t know why. Why exactly does a square op cause the output of the model to go to infinity after a few epochs, and why does that only happen when there is a validation loop?

Batch norm uses internal buffers to keep track of a running mean and variance per channel. I’m not sure why this is happening, but one thing to try, instead of having a batch norm layer both before and after the square, is to use only the conv layers there. Just a guess.
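
If you want to check whether those buffers are what is blowing up, something like the sketch below could help. It is only an illustration (check_bn_stats is a made-up helper name) that prints the running statistics of every batch norm layer in your artnet model; you could call it once after each training epoch and again after validation:

    # Sketch: print BatchNorm running statistics to see whether they blow up
    import torch
    import torch.nn as nn

    def check_bn_stats(model):
        for name, module in model.named_modules():
            if isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)) \
                    and module.running_mean is not None:
                rm, rv = module.running_mean, module.running_var
                finite = bool(torch.isfinite(rm).all()) and bool(torch.isfinite(rv).all())
                print(f'{name}: |running_mean| max = {rm.abs().max().item():.3e}, '
                      f'running_var max = {rv.max().item():.3e}, finite = {finite}')

    check_bn_stats(artnet)

If the running variance of the batch norm after the square keeps growing from epoch to epoch, that could explain why the problem only shows up once eval mode, which uses those running statistics instead of the batch statistics, is involved.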