I am working with multispectral images (nbands > 3), so I modified the resnet18 architecture as follows so that the input layer accepts more than 3 channels while still using the preloaded weights:
```python
import torch
from torchvision import models

# NOTE: out_features is not defined in this snippet; it is assumed to be
# defined elsewhere (e.g. as a module-level constant).


def get_model(arch, nbands):
    input_features = 512
    model = models.resnet18(pretrained=True)
    if nbands > 3:
        # Keep the pretrained RGB filters, then rebuild conv1 with nbands inputs.
        weight = model.conv1.weight.clone()
        model.conv1 = torch.nn.Conv2d(nbands, 64, kernel_size=7, stride=2, padding=3, bias=False)
        with torch.no_grad():
            # Cycle through the 3 pretrained channels: 0, 1, 2, 0, 1, ...
            channel = 0
            for i in range(nbands):
                model.conv1.weight[:, i] = weight[:, channel]
                channel += 1
                if channel == 3:
                    channel = 0
    model.fc = torch.nn.Sequential(
        torch.nn.Flatten(),
        torch.nn.Dropout(),
        torch.nn.Linear(in_features=input_features, out_features=out_features),
        torch.nn.Dropout()
    )
    return model


class Model(torch.nn.Module):
    def __init__(self, arch, nbands):
        super(Model, self).__init__()
        self.arch = get_model(arch, nbands)
        self.out_layer = torch.nn.Sequential(
            torch.nn.Linear(in_features=out_features, out_features=1),
            torch.nn.Sigmoid(),
        )

    def forward(self, x):
        x = self.arch(x)
        x = self.out_layer(x)
        return x
```
I created a separate output layer because I want the Model class to be generic (it should also work with VGG and other ResNet variants).
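For reference, the cyclic weight copy described above can also be written without the explicit loop. This is only a sketch: `expand_first_conv` is a hypothetical helper name, and a randomly initialised `Conv2d` stands in for the pretrained `conv1` so that nothing has to be downloaded.

```python
import torch


def expand_first_conv(conv, nbands):
    """Return a Conv2d with `nbands` input channels whose weights cycle
    through the input channels of `conv` (0, 1, 2, 0, 1, ...)."""
    new_conv = torch.nn.Conv2d(
        nbands, conv.out_channels,
        kernel_size=conv.kernel_size, stride=conv.stride,
        padding=conv.padding, bias=conv.bias is not None,
    )
    # Source channel for each new input channel, e.g. [0, 1, 2, 0, 1] for 5 bands.
    idx = torch.arange(nbands) % conv.in_channels
    with torch.no_grad():
        new_conv.weight.copy_(conv.weight[:, idx])
        if conv.bias is not None:
            new_conv.bias.copy_(conv.bias)
    return new_conv


# Stand-in for the pretrained resnet18 conv1 (random weights, assumption for the demo).
rgb_conv = torch.nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
ms_conv = expand_first_conv(rgb_conv, nbands=5)
out = ms_conv(torch.randn(2, 5, 64, 64))
```

With 5 bands, the fourth and fifth input channels reuse the filters of the first and second, which matches what the loop in `get_model` does.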
The loss function I am using is:
```python
import torch
import torch.nn as nn


class FocalLoss(nn.Module):
    def __init__(self, alpha=1, gamma=2):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        # Per-element BCE; expects probabilities (sigmoid outputs), not logits.
        self.loss = nn.BCELoss(reduction='none')

    def forward(self, inputs, targets):
        BCE_loss = self.loss(inputs, targets)
        # pt is the probability assigned to the true class.
        pt = torch.exp(-BCE_loss)
        F_loss = self.alpha * (1-pt)**self.gamma * BCE_loss
        return torch.mean(F_loss)
```
This code works with most hyperparameter settings (learning rate, weight decay, alpha and gamma), but in some cases the classifier starts predicting nan values after a few training steps and execution fails. Below are a few commands covering success and failure cases.
python main.py --alpha=1.307306834808073 --gamma=3.978496653097758 --learning_rate=0.02564559105742597 --weight_decay=0.08126251275894147
python main.py --alpha=0.908698154033616 --gamma=3.5530891911644256 --learning_rate=0.04984724359261675 --weight_decay=0.06717551752431922
python main.py --alpha=0.4310113846577901 --gamma=0.6493169166927948 --learning_rate=0.705326989623113 --weight_decay=0.06878194384201824
I then decreased the learning rate and weight decay, as suggested in some other discussions:
python main.py --alpha=0.4310113846577901 --gamma=0.6493169166927948 --learning_rate=0.0000000000705326989623113 --weight_decay=0.0006878194384201824
I also tried autograd anomaly detection, and it throws the following error:
RuntimeError: Function 'AddmmBackward' returned nan values in its 2th output.
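For context, an error of this kind is what `torch.autograd.detect_anomaly()` raises when a backward function produces nan. The sketch below is a toy example, not the training code from this post: `backward_with_anomaly_detection` is a hypothetical helper, and the nan is provoked on purpose with `sqrt` at zero.

```python
import torch


def backward_with_anomaly_detection(make_loss):
    """Build the loss and run backward with anomaly detection enabled, so a
    nan produced by any backward function raises a RuntimeError."""
    with torch.autograd.detect_anomaly():
        loss = make_loss()
        loss.backward()


x = torch.tensor(0.0, requires_grad=True)
try:
    # d/dx sqrt(x) is infinite at x = 0; multiplied by an upstream
    # gradient of 0 this gives 0 * inf = nan in the backward pass.
    backward_with_anomaly_detection(lambda: torch.sqrt(x) * 0.0)
    message = None
except RuntimeError as err:
    message = str(err)

print(message)
```

The message names the backward function that first returned nan, which is how `AddmmBackward` (the backward of a linear layer) shows up in the error above.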
I am not sure whether this is an issue with the architecture or with the loss function. I don't believe the dataset is the cause (it is a manually curated satellite imagery dataset), since training works in most cases.
[UPDATE]: I tried removing the alpha and gamma values from the failed commands (i.e. using default values for them) and then the code works fine.
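One thing that stands out in the commands is that the failing ones use gamma < 1, while the succeeding ones and the default use gamma > 1. The sketch below shows one way a nan gradient can arise in this loss when gamma < 1. It assumes a prediction that has saturated to exactly 1.0 for a positive target; whether that actually happens in the failing runs is not confirmed. `focal_loss_grad` is a hypothetical helper that repeats the arithmetic of `FocalLoss.forward` for a single element.

```python
import torch
import torch.nn as nn


def focal_loss_grad(gamma, alpha=1.0):
    """Gradient of the focal loss w.r.t. one prediction that is exactly 1.0
    for a positive target (assumed saturated sigmoid output)."""
    p = torch.tensor([1.0], requires_grad=True)
    target = torch.tensor([1.0])
    bce = nn.BCELoss(reduction='none')(p, target)  # exactly 0 for this input
    pt = torch.exp(-bce)                           # exactly 1, so (1 - pt) == 0
    loss = (alpha * (1 - pt) ** gamma * bce).mean()
    loss.backward()
    return p.grad


# gamma < 1: the derivative of (1 - pt) ** gamma is infinite at 0, and the
# backward pass multiplies it by a zero upstream gradient: 0 * inf = nan.
grad_small_gamma = focal_loss_grad(gamma=0.6493169166927948)

# gamma = 2 (the default): the same derivative is 0 at 0, so the gradient stays finite.
grad_default_gamma = focal_loss_grad(gamma=2)

print(grad_small_gamma, grad_default_gamma)
```

The loss value itself is 0 in both cases; only the gradient differs, which would be consistent with nan first appearing in a backward function rather than in the forward pass.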