I am working with multispectral images (nbands > 3), so I modified the ResNet-18 architecture as follows so that its input layer accepts more than 3 channels while still loading the pretrained weights:
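import torch
from torchvision import models

# NOTE: out_features is not defined in this snippet; it is assumed to be a
# module-level constant defined elsewhere.

def get_model(arch, nbands):
    input_features = 512  # size of the ResNet-18 feature vector fed to fc
    model = models.resnet18(pretrained=True)
    if nbands > 3:
        # Keep a copy of the pretrained RGB filters before replacing conv1.
        weight = model.conv1.weight.clone()
        model.conv1 = torch.nn.Conv2d(nbands, 64, kernel_size=7, stride=2, padding=3, bias=False)
        with torch.no_grad():
            # Fill the new input channels by cycling through the 3 pretrained ones.
            channel = 0
            for i in range(nbands):
                model.conv1.weight[:, i] = weight[:, channel]
                channel += 1
                if channel == 3:
                    channel = 0
    model.fc = torch.nn.Sequential(
        torch.nn.Flatten(),
        torch.nn.Dropout(),
        torch.nn.Linear(in_features=input_features, out_features=out_features),
        torch.nn.Dropout()
    )
    return model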
class Model(torch.nn.Module):
    def __init__(self, arch, nbands):
        super(Model, self).__init__()
        self.arch = get_model(arch, nbands)
        self.out_layer = torch.nn.Sequential(
            torch.nn.Linear(in_features=out_features, out_features=1),
            torch.nn.Sigmoid(),
        )

    def forward(self, x):
        x = self.arch(x)
        x = self.out_layer(x)
        return x
I created a separate output layer because I want the Model class to be generic (it should also work with VGG and other ResNet variants).
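For reference, the channel-cycling copy in get_model can also be expressed as a single indexing operation. This is only a sketch: the helper name `cycle_rgb_weights` is hypothetical, and a random tensor stands in for the pretrained `conv1` weights so the snippet runs without torchvision.

```python
import torch

def cycle_rgb_weights(rgb_weight: torch.Tensor, nbands: int) -> torch.Tensor:
    """Fill nbands input channels by cycling through the pretrained ones."""
    # Channel indices 0, 1, 2, 0, 1, ... (one entry per new input channel).
    idx = torch.arange(nbands) % rgb_weight.shape[1]
    # Indexing with a tensor returns a new tensor, so the source is untouched.
    return rgb_weight[:, idx]

# Random stand-in for model.conv1.weight, shape (out_channels, 3, 7, 7).
rgb_weight = torch.randn(64, 3, 7, 7)
new_weight = cycle_rgb_weights(rgb_weight, nbands=5)
print(new_weight.shape)
```

The result could then be copied into the new layer inside `torch.no_grad()`, for example with `model.conv1.weight.copy_(new_weight)`.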
The loss function I am using is:
import torch
from torch import nn

class FocalLoss(nn.Module):
    def __init__(self, alpha=1, gamma=2):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        # Per-element BCE on probabilities (the model already applies Sigmoid).
        self.loss = nn.BCELoss(reduction='none')

    def forward(self, inputs, targets):
        BCE_loss = self.loss(inputs, targets)
        pt = torch.exp(-BCE_loss)  # probability assigned to the true class
        F_loss = self.alpha * (1-pt)**self.gamma * BCE_loss
        return torch.mean(F_loss)
This code works with most hyperparameter settings (learning rate, weight decay, alpha, and gamma), but in some cases, after a few training steps, the classifier predicts 0 or nan values and execution fails. Below are a few commands for the success and failure cases.
Success:
python main.py --alpha=1.307306834808073 --gamma=3.978496653097758 --learning_rate=0.02564559105742597 --weight_decay=0.08126251275894147
python main.py --alpha=0.908698154033616 --gamma=3.5530891911644256 --learning_rate=0.04984724359261675 --weight_decay=0.06717551752431922
Failure:
1st Encounter:
python main.py --alpha=0.4310113846577901 --gamma=0.6493169166927948 --learning_rate=0.705326989623113 --weight_decay=0.06878194384201824
I then decreased the learning rate and weight decay, as suggested in some other discussions:
python main.py --alpha=0.4310113846577901 --gamma=0.6493169166927948 --learning_rate=0.0000000000705326989623113 --weight_decay=0.0006878194384201824
I also tried autograd anomaly detection, and it throws the following error:
RuntimeError: Function 'AddmmBackward' returned nan values in its 2th output.
I am not sure whether this is an issue with the architecture or with the loss function. I do not believe the error is due to the dataset (a manually curated satellite imagery dataset), since training works in most cases.
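One way to separate the two suspects is to backpropagate through the loss alone, with no network attached. The sketch below re-implements the focal loss from this post as a function (the names `focal_loss` and `grad_at_saturation` are hypothetical) and evaluates the gradient for a prediction that has saturated to exactly 1.0. The saturated input is an assumption about what happens during training, not something confirmed from the logs.

```python
import torch
from torch import nn

def focal_loss(inputs, targets, alpha=1.0, gamma=2.0):
    """Same computation as FocalLoss.forward in the post, as a function."""
    bce = nn.functional.binary_cross_entropy(inputs, targets, reduction="none")
    pt = torch.exp(-bce)
    return torch.mean(alpha * (1 - pt) ** gamma * bce)

def grad_at_saturation(gamma):
    # Assumed scenario: the sigmoid output has saturated to exactly 1.0
    # for a positive target, so bce == 0 and pt == 1.
    p = torch.tensor([1.0], requires_grad=True)
    y = torch.tensor([1.0])
    focal_loss(p, y, alpha=0.43, gamma=gamma).backward()
    return p.grad

grad_low_gamma = grad_at_saturation(0.65)  # gamma < 1, as in the failing runs
grad_default = grad_at_saturation(2.0)     # gamma = 2, the class default

print("gamma=0.65 gradient has nan:", torch.isnan(grad_low_gamma).any().item())
print("gamma=2.0 gradient is finite:", torch.isfinite(grad_default).all().item())
```

If the gradient is already nan here, the loss is a plausible source regardless of the architecture, because the derivative of `(1 - pt) ** gamma` contains `(1 - pt) ** (gamma - 1)`, which is unbounded at `pt == 1` when gamma < 1.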
[UPDATE]: I removed the alpha and gamma arguments from the failing commands (i.e. used their default values), and the code then works fine.
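Since the failing commands use gamma < 1 and the defaults work, one possible (unconfirmed) explanation is that the modulating factor `(1 - pt) ** gamma` has an unbounded derivative as `pt` approaches 1. Below is a minimal sketch of a guarded variant that keeps the base of the power away from zero. The class name `ClampedFocalLoss` and the `eps` parameter are hypothetical and not part of the original code.

```python
import torch
from torch import nn

class ClampedFocalLoss(nn.Module):
    """Hypothetical variant of FocalLoss; eps is an assumed parameter."""

    def __init__(self, alpha=1.0, gamma=2.0, eps=1e-6):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.eps = eps
        self.loss = nn.BCELoss(reduction="none")

    def forward(self, inputs, targets):
        bce = self.loss(inputs, targets)
        pt = torch.exp(-bce)
        # Keep the base of the power at or above eps so the derivative
        # gamma * base ** (gamma - 1) stays finite when gamma < 1.
        modulator = (1 - pt).clamp(min=self.eps) ** self.gamma
        return torch.mean(self.alpha * modulator * bce)

# The first prediction is saturated at exactly 1.0 for a positive target.
p = torch.tensor([1.0, 0.7, 0.2], requires_grad=True)
y = torch.tensor([1.0, 1.0, 0.0])
loss = ClampedFocalLoss(alpha=0.43, gamma=0.65)(p, y)
loss.backward()
print("all gradients finite:", torch.isfinite(p.grad).all().item())
```

The clamp changes the loss slightly for very confident predictions, because the modulating factor bottoms out at `eps ** gamma` instead of 0, so whether this trade-off is acceptable depends on the task.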