Classifier (ResNet + custom layer) returns NaN

I am working with multispectral images (nbands > 3), so I modified the resnet18 architecture as shown below to accept more than 3 input channels while still starting from the pretrained weights:

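import torch
from torchvision import models

def get_model(arch, nbands, out_features):
    # `arch` is not used yet: resnet18 is hard-coded here
    input_features = 512  # in_features of resnet18's final fc layer
    # Newer torchvision versions prefer weights=models.ResNet18_Weights.DEFAULT
    model = models.resnet18(pretrained=True)
    if nbands > 3:
        weight = model.conv1.weight.clone()  # pretrained RGB filters, shape (64, 3, 7, 7)
        model.conv1 = torch.nn.Conv2d(nbands, 64, kernel_size=7, stride=2, padding=3, bias=False)
        with torch.no_grad():
            # Band i is initialised with the pretrained filter of RGB channel i % 3
            channel = 0
            for i in range(nbands):
                model.conv1.weight[:, i] = weight[:, channel]
                channel += 1
                if channel == 3:
                    channel = 0
    # Replace the 1000-class ImageNet head with an out_features-dimensional layer
    model.fc = torch.nn.Sequential(
        torch.nn.Linear(in_features=input_features, out_features=out_features),
    )
    return model

class Model(torch.nn.Module):
    def __init__(self, arch, nbands, out_features):
        super(Model, self).__init__()
        self.arch = get_model(arch, nbands, out_features)
        self.out_layer = torch.nn.Sequential(
            torch.nn.Linear(in_features=out_features, out_features=1),
            # Assumed: a sigmoid, because nn.BCELoss in FocalLoss expects probabilities in [0, 1]
            torch.nn.Sigmoid(),
        )

    def forward(self, x):
        x = self.arch(x)       # (batch, out_features)
        x = self.out_layer(x)  # (batch, 1), probability of the positive class
        return x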
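The loop in get_model gives input band i the pretrained filter of RGB channel i % 3. As a minimal sketch, the same copy can be written as a single indexing operation; here a random tensor stands in for the pretrained conv1 weights so nothing has to be downloaded, and `expand_conv1_weight` is a hypothetical helper name:

```python
import torch

def expand_conv1_weight(weight: torch.Tensor, nbands: int) -> torch.Tensor:
    """Map an RGB conv weight of shape (out, 3, kH, kW) to (out, nbands, kH, kW)
    by cycling through the pretrained channels (band i -> channel i % 3)."""
    index = torch.arange(nbands) % weight.shape[1]
    # Advanced indexing with a LongTensor returns a new tensor (a copy)
    return weight[:, index]

# Random stand-in for resnet18's pretrained conv1 weight, shape (64, 3, 7, 7)
rgb_weight = torch.randn(64, 3, 7, 7)
conv1 = torch.nn.Conv2d(5, 64, kernel_size=7, stride=2, padding=3, bias=False)
with torch.no_grad():
    conv1.weight.copy_(expand_conv1_weight(rgb_weight, nbands=5))

print(conv1.weight.shape)  # torch.Size([64, 5, 7, 7])
```

The result is identical to the explicit loop; the vectorised form just makes the i % 3 mapping visible at a glance.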

I created a separate output layer because I want the Model class to be generic (it should also work with VGG and other ResNet variants).

The loss function I am using is:

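import torch
from torch import nn

class FocalLoss(nn.Module):
    def __init__(self, alpha=1, gamma=2):
        super(FocalLoss, self).__init__()
        self.alpha = alpha  # weighting factor
        self.gamma = gamma  # focusing parameter
        # Expects probabilities in [0, 1] (i.e. after a sigmoid), not raw logits
        self.loss = nn.BCELoss(reduction='none')

    def forward(self, inputs, targets):
        BCE_loss = self.loss(inputs, targets)  # per-element binary cross-entropy
        pt = torch.exp(-BCE_loss)              # probability assigned to the true class
        F_loss = self.alpha * (1-pt)**self.gamma * BCE_loss

        return torch.mean(F_loss)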
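A commonly used variant (swapped in here for illustration, not the code above) computes the BCE term from raw logits with `torch.nn.functional.binary_cross_entropy_with_logits`, which fuses the sigmoid and the BCE and is numerically more stable than a sigmoid followed by `nn.BCELoss`. This sketch also clamps `1 - pt` to a small `eps` (an assumed value of 1e-6) so that `(1 - pt) ** gamma` keeps a finite gradient even for gamma < 1; the clamp slightly changes the loss value only when `1 - pt < eps`. If the model's output layer ends with a sigmoid, it would be dropped when using this loss.

```python
import torch
import torch.nn.functional as F

class FocalLossWithLogits(torch.nn.Module):
    """Hypothetical focal-loss variant that takes raw logits."""

    def __init__(self, alpha=1.0, gamma=2.0, eps=1e-6):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.eps = eps  # assumed lower bound for (1 - pt)

    def forward(self, logits, targets):
        # Sigmoid + BCE in one numerically stable op
        bce = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
        pt = torch.exp(-bce)  # probability assigned to the true class
        # Clamp keeps the pow gradient finite when pt rounds to 1.0
        modulating = (1 - pt).clamp(min=self.eps) ** self.gamma
        return torch.mean(self.alpha * modulating * bce)

loss_fn = FocalLossWithLogits(alpha=0.431, gamma=0.649)
logits = torch.tensor([30.0, -2.0, 0.5], requires_grad=True)
targets = torch.tensor([1.0, 0.0, 1.0])
loss = loss_fn(logits, targets)
loss.backward()
print(loss.item(), logits.grad)
```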

This code works for most hyperparameter combinations (learning rate, weight decay, alpha and gamma), but in some cases, after a few training steps, the classifier predicts 0 or NaN values and execution fails. Here are a few commands for successful and failing runs.

Successful runs:

python --alpha=1.307306834808073 --gamma=3.978496653097758 --learning_rate=0.02564559105742597 --weight_decay=0.08126251275894147
python --alpha=0.908698154033616 --gamma=3.5530891911644256 --learning_rate=0.04984724359261675 --weight_decay=0.06717551752431922

Failing run (1st encounter):

python --alpha=0.4310113846577901 --gamma=0.6493169166927948 --learning_rate=0.705326989623113 --weight_decay=0.06878194384201824

I then decreased the learning rate and weight decay, as suggested in some other discussions:

python --alpha=0.4310113846577901 --gamma=0.6493169166927948 --learning_rate=0.0000000000705326989623113 --weight_decay=0.0006878194384201824

With autograd anomaly detection enabled, it throws the following error:

RuntimeError: Function 'AddmmBackward' returned nan values in its 2th output.
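Anomaly detection reports the first backward function that returns NaN, which is not necessarily where a non-finite value first appeared: it checks only for NaN, so an Inf can pass through earlier functions and become NaN later. As a complementary check, a small hypothetical helper (`nonfinite_grads`, not part of PyTorch) can list which parameters ended up with NaN or Inf gradients right after `loss.backward()`. The toy example uses a single `torch.nn.Linear` layer with a NaN in its input:

```python
import torch

def nonfinite_grads(model: torch.nn.Module) -> list[str]:
    """Names of parameters whose .grad contains NaN or Inf (call after backward())."""
    return [
        name
        for name, param in model.named_parameters()
        if param.grad is not None and not torch.isfinite(param.grad).all()
    ]

layer = torch.nn.Linear(4, 1)
x = torch.tensor([[1.0, float("nan"), 0.0, 2.0]])
layer(x).sum().backward()
# The weight gradient is upstream_grad^T @ x, so it inherits the NaN;
# the bias gradient is just the summed upstream gradient and stays finite.
print(nonfinite_grads(layer))  # ['weight']
```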

I am not sure whether this is an issue with the architecture or with the loss function. I don't think the dataset (manually curated satellite imagery) is the cause, since training works for most hyperparameter settings.

[UPDATE]: When I remove the alpha and gamma arguments from the failing commands (i.e. use their default values), the code works fine.
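One possible explanation, consistent with the update: the failing commands use gamma ≈ 0.649, while the working ones use gamma > 3. For gamma < 1, the derivative of the modulating factor, gamma * (1 - pt) ** (gamma - 1), is infinite at pt = 1. If a prediction saturates so that BCE_loss is exactly 0, autograd multiplies that infinity by a zero upstream gradient and produces NaN. If BCE_loss is tiny but nonzero while pt still rounds to 1.0, it produces Inf instead, which can become NaN further back (for example Inf - Inf inside a matrix multiply such as the AddmmBackward reported above). The sketch below reproduces the first case with the same formula as FocalLoss.forward; the helper names and the logit value 30 are illustrative assumptions.

```python
import torch
import torch.nn.functional as F

def focal_loss(probs, targets, alpha, gamma):
    # Same computation as FocalLoss.forward above (BCE on probabilities)
    bce = F.binary_cross_entropy(probs, targets, reduction="none")
    pt = torch.exp(-bce)
    return torch.mean(alpha * (1 - pt) ** gamma * bce)

def logit_grad(gamma, logit=30.0):
    # A large logit saturates the float32 sigmoid to exactly 1.0,
    # so BCE == 0, pt == 1 and (1 - pt) == 0
    x = torch.tensor([logit], requires_grad=True)
    loss = focal_loss(torch.sigmoid(x), torch.tensor([1.0]), alpha=0.431, gamma=gamma)
    loss.backward()
    return x.grad

print(logit_grad(gamma=2.0))    # finite: 2 * 0 ** 1 == 0
print(logit_grad(gamma=0.649))  # NaN: 0 * (0.649 * 0 ** -0.351) == 0 * inf
```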